zrepl/internal/rpc/rpc_server.go
2024-10-18 19:21:17 +02:00

133 lines
4.7 KiB
Go

package rpc
import (
"context"
"time"
"github.com/zrepl/zrepl/internal/endpoint"
"github.com/zrepl/zrepl/internal/replication/logic/pdu"
"github.com/zrepl/zrepl/internal/rpc/dataconn"
"github.com/zrepl/zrepl/internal/rpc/grpcclientidentity"
"github.com/zrepl/zrepl/internal/rpc/grpcclientidentity/grpchelper"
"github.com/zrepl/zrepl/internal/rpc/versionhandshake"
"github.com/zrepl/zrepl/internal/transport"
"github.com/zrepl/zrepl/internal/util/envconst"
)
// Handler is the set of request handlers that a Server is constructed with.
// It combines the replication service interface (pdu.ReplicationServer),
// which NewServer registers on the control server, with dataconn.Handler,
// which NewServer hands to the data server.
type Handler interface {
pdu.ReplicationServer
dataconn.Handler
}
// serveFunc runs one of the Server's sub-servers on demuxedListener.
// As used in this file, an implementation sends exactly one value on errOut
// once it has finished serving (nil if there is nothing to report).
type serveFunc func(ctx context.Context, demuxedListener transport.AuthenticatedListener, errOut chan<- error)
// Server abstracts the accept and request routing infrastructure for the
// passive side of a replication setup.
//
// Instances are created by NewServer and run by (*Server).Serve.
type Server struct {
// logger receives the Server's own log messages (loggers.General in NewServer).
logger Logger
// handler is the Handler that was passed to NewServer.
handler Handler
// controlServerServe runs the control server on a listener.
controlServerServe serveFunc
// dataServer is the data connection server constructed in NewServer.
dataServer *dataconn.Server
// dataServerServe runs dataServer on a listener.
dataServerServe serveFunc
}
// HandlerContextInterceptorData describes the request for which a
// HandlerContextInterceptor is being invoked.
type HandlerContextInterceptorData interface {
// FullMethod returns the name of the method being invoked. For
// interceptors installed through NewServer the value is prefixed with
// "control://" or "data://".
FullMethod() string
// ClientIdentity returns the client identity associated with the request.
ClientIdentity() string
}
// interceptorData decorates a HandlerContextInterceptorData so that the method
// name it reports carries a fixed prefix.
//
// NOTE: values of this type are built with unkeyed composite literals
// (interceptorData{prefix, wrapped}), so the field order must not change.
type interceptorData struct {
	prefixMethod string                        // prepended to wrapped.FullMethod()
	wrapped      HandlerContextInterceptorData // source of the actual request data
}

// ClientIdentity reports the wrapped data's client identity unchanged.
func (d interceptorData) ClientIdentity() string {
	identity := d.wrapped.ClientIdentity()
	return identity
}

// FullMethod reports the wrapped data's method name with prefixMethod
// prepended.
func (d interceptorData) FullMethod() string {
	method := d.wrapped.FullMethod()
	return d.prefixMethod + method
}
// HandlerContextInterceptor is the hook signature accepted by NewServer.
// It receives the request context, data describing the request, and the
// handler continuation.
// NOTE(review): presumably an interceptor is expected to call handler with the
// (possibly derived) context — confirm against grpcclientidentity and dataconn.
type HandlerContextInterceptor func(ctx context.Context, data HandlerContextInterceptorData, handler func(ctx context.Context))
// NewServer constructs a Server that dispatches control (gRPC) requests and
// data connection requests to handler.
//
// loggers supplies the loggers for the control server (loggers.Control), the
// data server (loggers.Data) and the Server itself (loggers.General).
//
// ctxInterceptor is handed to both sub-servers, wrapped so that the FullMethod
// it observes is prefixed with "control://" or "data://" depending on which
// server the request arrived on.
//
// The returned Server does not accept connections until Serve is called.
func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextInterceptor) *Server {

	// setup control server
	controlServerServe := func(ctx context.Context, controlListener transport.AuthenticatedListener, errOut chan<- error) {

		var controlCtxInterceptor grpcclientidentity.Interceptor = func(ctx context.Context, data grpcclientidentity.ContextInterceptorData, next func(ctx context.Context)) {
			ctxInterceptor(ctx, interceptorData{"control://", data}, next)
		}
		controlServer, serve := grpchelper.NewServer(controlListener, endpoint.ClientIdentityKey, loggers.Control, controlCtxInterceptor)
		pdu.RegisterReplicationServer(controlServer, handler)

		// give time for graceful stop until deadline expires, then hard stop
		go func() {
			<-ctx.Done()
			if dl, ok := ctx.Deadline(); ok {
				// Schedule the hard stop for the time remaining until the
				// deadline. If the deadline has already passed, the duration
				// is <= 0 and the hard stop fires immediately.
				hardStop := time.AfterFunc(time.Until(dl), controlServer.Stop)
				// No need to keep the timer around once the graceful stop
				// below has returned.
				defer hardStop.Stop()
			}
			loggers.Control.Debug("gracefully shutting down control server")
			controlServer.GracefulStop()
			loggers.Control.Debug("gracdeful shut down of control server complete")
		}()

		errOut <- serve()
	}

	// setup data server
	dataServerClientIdentitySetter := func(ctx context.Context, wire *transport.AuthConn) (context.Context, *transport.AuthConn) {
		// Expose the connection's client identity under the same context key
		// that the control server is configured with.
		ci := wire.ClientIdentity()
		ctx = context.WithValue(ctx, endpoint.ClientIdentityKey, ci)
		return ctx, wire
	}
	var dataCtxInterceptor dataconn.ContextInterceptor = func(ctx context.Context, data dataconn.ContextInterceptorData, next func(ctx context.Context)) {
		ctxInterceptor(ctx, interceptorData{"data://", data}, next)
	}
	dataServer := dataconn.NewServer(dataServerClientIdentitySetter, dataCtxInterceptor, loggers.Data, handler)
	dataServerServe := func(ctx context.Context, dataListener transport.AuthenticatedListener, errOut chan<- error) {
		dataServer.Serve(ctx, dataListener)
		errOut <- nil // TODO bad design of dataServer?
	}

	server := &Server{
		logger:             loggers.General,
		handler:            handler,
		controlServerServe: controlServerServe,
		dataServer:         dataServer,
		dataServerServe:    dataServerServe,
	}

	return server
}
// Serve runs the control server and the data server on top of l and blocks
// until both of them have shut down.
//
// ctx is used for cancellation only. Serve does not return an error; whatever
// the two servers report is written to the Server's logger.
func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	defer s.logger.Debug("rpc.(*Server).Serve done")

	handshakeTimeout := envconst.Duration("ZREPL_RPC_SERVER_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second)
	l = versionhandshake.Listener(l, handshakeTimeout)

	// demux has background goroutines attached to its context, so that
	// context must be cancelled on every path out of Serve.
	mux := demux(ctx, l)

	// Capacity 2: one slot per sub-server, so neither blocks when reporting.
	results := make(chan error, 2)
	go s.controlServerServe(ctx, mux.control, results)
	go s.dataServerServe(ctx, mux.data, results)

	select {
	case <-ctx.Done():
		s.logger.Debug("context cancelled, wait for control and data servers")
		cancel()
		for pending := 2; pending > 0; pending-- {
			<-results
		}
		s.logger.Debug("control and data server shut down, returning from Serve")

	case firstErr := <-results:
		// One server finished on its own: report it, then bring the other
		// one down and wait for its result as well.
		s.logger.WithError(firstErr).Error("serve error")
		s.logger.Debug("wait for other server to shut down")
		cancel()
		otherErr := <-results
		s.logger.WithError(otherErr).Error("serve error")
	}
}