package rpc import ( "context" "time" "google.golang.org/grpc" "github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/replication/logic/pdu" "github.com/zrepl/zrepl/rpc/dataconn" "github.com/zrepl/zrepl/rpc/grpcclientidentity" "github.com/zrepl/zrepl/rpc/netadaptor" "github.com/zrepl/zrepl/rpc/versionhandshake" "github.com/zrepl/zrepl/transport" "github.com/zrepl/zrepl/util/envconst" ) type Handler interface { pdu.ReplicationServer dataconn.Handler } type serveFunc func(ctx context.Context, demuxedListener transport.AuthenticatedListener, errOut chan<- error) // Server abstracts the accept and request routing infrastructure for the // passive side of a replication setup. type Server struct { logger Logger handler Handler controlServer *grpc.Server controlServerServe serveFunc dataServer *dataconn.Server dataServerServe serveFunc } type HandlerContextInterceptor func(ctx context.Context) context.Context // config must be valid (use its Validate function). func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextInterceptor) *Server { // setup control server tcs := grpcclientidentity.NewTransportCredentials(loggers.Control) // TODO different subsystem for log unary, stream := grpcclientidentity.NewInterceptors(loggers.Control, endpoint.ClientIdentityKey) controlServer := grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream)) pdu.RegisterReplicationServer(controlServer, handler) controlServerServe := func(ctx context.Context, controlListener transport.AuthenticatedListener, errOut chan<- error) { // give time for graceful stop until deadline expires, then hard stop go func() { <-ctx.Done() if dl, ok := ctx.Deadline(); ok { go time.AfterFunc(dl.Sub(dl), controlServer.Stop) } loggers.Control.Debug("shutting down control server") controlServer.GracefulStop() }() errOut <- controlServer.Serve(netadaptor.New(controlListener, loggers.Control)) } // setup data server dataServerClientIdentitySetter := func(ctx context.Context, wire *transport.AuthConn) (context.Context, *transport.AuthConn) { ci := wire.ClientIdentity() ctx = context.WithValue(ctx, endpoint.ClientIdentityKey, ci) if ctxInterceptor != nil { ctx = ctxInterceptor(ctx) // SHADOWING } return ctx, wire } dataServer := dataconn.NewServer(dataServerClientIdentitySetter, loggers.Data, handler) dataServerServe := func(ctx context.Context, dataListener transport.AuthenticatedListener, errOut chan<- error) { dataServer.Serve(ctx, dataListener) errOut <- nil // TODO bad design of dataServer? } server := &Server{ logger: loggers.General, handler: handler, controlServer: controlServer, controlServerServe: controlServerServe, dataServer: dataServer, dataServerServe: dataServerServe, } return server } // The context is used for cancellation only. // Serve never returns an error, it logs them to the Server's logger. func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { ctx, cancel := context.WithCancel(ctx) l = versionhandshake.Listener(l, envconst.Duration("ZREPL_RPC_SERVER_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second)) // it is important that demux's context is cancelled, // it has background goroutines attached demuxListener := demux(ctx, l) serveErrors := make(chan error, 2) go s.controlServerServe(ctx, demuxListener.control, serveErrors) go s.dataServerServe(ctx, demuxListener.data, serveErrors) select { case serveErr := <-serveErrors: s.logger.WithError(serveErr).Error("serve error") s.logger.Debug("wait for other server to shut down") cancel() secondServeErr := <-serveErrors s.logger.WithError(secondServeErr).Error("serve error") case <-ctx.Done(): s.logger.Debug("context cancelled, wait for control and data servers") cancel() for i := 0; i < 2; i++ { <-serveErrors } s.logger.Debug("control and data server shut down, returning from Serve") } }