diff --git a/rpc/dataconn/dataconn_server.go b/rpc/dataconn/dataconn_server.go index aecaf96..4fa6899 100644 --- a/rpc/dataconn/dataconn_server.go +++ b/rpc/dataconn/dataconn_server.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "sync" "github.com/golang/protobuf/proto" @@ -50,16 +51,26 @@ func NewServer(wi WireInterceptor, logger Logger, handler Handler) *Server { // No accept errors are returned: they are logged to the Logger passed // to the constructor. func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { + var wg sync.WaitGroup + defer wg.Wait() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + wg.Add(1) go func() { + defer wg.Done() <-ctx.Done() - s.log.Debug("context done") + s.log.Debug("context done, closing listener") if err := l.Close(); err != nil { s.log.WithError(err).Error("cannot close listener") } }() conns := make(chan *transport.AuthConn) + wg.Add(1) go func() { + defer wg.Done() + defer close(conns) for { conn, err := l.Accept(ctx) if err != nil { @@ -74,7 +85,11 @@ func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { } }() for conn := range conns { - go s.serveConn(conn) + wg.Add(1) + go func(conn *transport.AuthConn) { + defer wg.Done() + s.serveConn(conn) + }(conn) } } diff --git a/rpc/netadaptor/authlistener_netlistener_adaptor.go b/rpc/netadaptor/authlistener_netlistener_adaptor.go index 05f9611..b15259d 100644 --- a/rpc/netadaptor/authlistener_netlistener_adaptor.go +++ b/rpc/netadaptor/authlistener_netlistener_adaptor.go @@ -33,8 +33,12 @@ import ( type Logger = logger.Logger +type acceptRes struct { + conn *transport.AuthConn + err error +} type acceptReq struct { - callback chan net.Conn + callback chan acceptRes } type Listener struct { @@ -64,10 +68,19 @@ func New(authListener transport.AuthenticatedListener, l Logger) *Listener { // The returned net.Conn is guaranteed to be *transport.AuthConn, i.e., the type of connection // returned by the wrapped transport.AuthenticatedListener. func (a Listener) Accept() (net.Conn, error) { - req := acceptReq{make(chan net.Conn, 1)} - a.accepts <- req - conn := <-req.callback - return conn, nil + req := acceptReq{make(chan acceptRes, 1)} + + select { + case a.accepts <- req: + case <-a.stop: + return nil, fmt.Errorf("already closed") // TODO net.Error + } + + res, ok := <-req.callback + if !ok { + return nil, fmt.Errorf("already closed") // TODO net.Error + } + return res.conn, res.err } func (a Listener) handleAccept() { @@ -77,18 +90,9 @@ func (a Listener) handleAccept() { a.logger.Debug("handleAccept stop accepting") return case req := <-a.accepts: - for { - a.logger.Debug("accept authListener") - authConn, err := a.al.Accept(context.Background()) - if err != nil { - a.logger.WithError(err).Error("accept error") - continue - } - a.logger.WithField("type", fmt.Sprintf("%T", authConn)). - Debug("accept complete") - req.callback <- authConn - break - } + a.logger.Debug("accept authListener") + authConn, err := a.al.Accept(context.Background()) + req.callback <- acceptRes{authConn, err} } } } diff --git a/rpc/rpc_server.go b/rpc/rpc_server.go index 69e6335..7363860 100644 --- a/rpc/rpc_server.go +++ b/rpc/rpc_server.go @@ -47,8 +47,9 @@ func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextIn if dl, ok := ctx.Deadline(); ok { go time.AfterFunc(dl.Sub(dl), controlServer.Stop) } - loggers.Control.Debug("shutting down control server") + loggers.Control.Debug("gracefully shutting down control server") controlServer.GracefulStop() + loggers.Control.Debug("gracdeful shut down of control server complete") }() errOut <- serve() @@ -84,6 +85,8 @@ func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextIn // Serve never returns an error, it logs them to the Server's logger. func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { ctx, cancel := context.WithCancel(ctx) + defer cancel() + defer s.logger.Debug("rpc.(*Server).Serve done") l = versionhandshake.Listener(l, envconst.Duration("ZREPL_RPC_SERVER_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second)) diff --git a/rpc/transportmux/transportmux.go b/rpc/transportmux/transportmux.go index 90ec48c..0f6f69a 100644 --- a/rpc/transportmux/transportmux.go +++ b/rpc/transportmux/transportmux.go @@ -7,6 +7,8 @@ package transportmux import ( "context" + "sync/atomic" + "syscall" "fmt" "io" @@ -42,12 +44,31 @@ type acceptRes struct { } type demuxListener struct { - conns chan acceptRes + closed int32 + conns chan acceptRes +} + +var ErrClosed = &net.OpError{ + Op: "accept", + Net: "demux", + Source: nil, + Addr: nil, + Err: syscall.EINVAL, } func (l *demuxListener) Accept(ctx context.Context) (*transport.AuthConn, error) { - res := <-l.conns - return res.conn, res.err + if atomic.LoadInt32(&l.closed) != 0 { + return nil, ErrClosed + } + select { + case r, ok := <-l.conns: + if !ok { + return nil, ErrClosed + } + return r.conn, r.err + case <-ctx.Done(): + return nil, ctx.Err() + } } type demuxAddr struct{} @@ -59,7 +80,10 @@ func (l *demuxListener) Addr() net.Addr { return demuxAddr{} } -func (l *demuxListener) Close() error { return nil } // TODO +func (l *demuxListener) Close() error { + atomic.StoreInt32(&l.closed, 1) + return nil +} // Exact length of a label in bytes (0-byte padded if it is shorter). // This is a protocol constant, changing it breaks the wire protocol. @@ -90,7 +114,10 @@ func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, lab if _, ok := padded[labelPadded]; ok { return nil, fmt.Errorf("duplicate label %q", label) } - dl := &demuxListener{make(chan acceptRes)} + dl := &demuxListener{ + closed: 0, + conns: make(chan acceptRes, 1), + } padded[labelPadded] = dl ret[label] = dl } @@ -103,10 +130,37 @@ func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, lab if err := rawListener.Close(); err != nil { getLog(ctx).WithError(err).Error("error closing listener") } + + drainConns := func(ch chan acceptRes) { + for c := range ch { + if c.conn != nil { + if err := c.conn.Close(); err != nil { + getLog(ctx).WithError(err).Error("error closing connection while draining after listener was closed") + } + } + } + } + for _, dl := range ret { + atomic.StoreInt32(&dl.(*demuxListener).closed, 1) + drainConns(dl.(*demuxListener).conns) + } }() go func() { + defer func() { + for _, dl := range ret { + close(dl.(*demuxListener).conns) + } + }() + for { + select { + case <-ctx.Done(): + getLog(ctx).WithError(ctx.Err()).Info("stop accepting new connections after context done") + return + default: + } + rawConn, err := rawListener.Accept(ctx) if err != nil { if ctx.Err() != nil { @@ -147,7 +201,6 @@ func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, lab if err != nil { getLog(ctx).WithError(err).Error("cannot reset deadline") } - // blocking is intentional demuxListener.conns <- acceptRes{conn: rawConn, err: nil} } }() diff --git a/transport/tcp/serve_tcp.go b/transport/tcp/serve_tcp.go index b9627dd..6a7679e 100644 --- a/transport/tcp/serve_tcp.go +++ b/transport/tcp/serve_tcp.go @@ -65,6 +65,12 @@ type TCPAuthListener struct { } func (f *TCPAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-ctx.Done() + cancel() + }() nc, err := f.TCPListener.AcceptTCP() if err != nil { return nil, err