From 97e10e440cf172e94921c983395c56252c0e429e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 16 Sep 2024 16:11:10 +0200 Subject: [PATCH] Fix leaked server connections (#2596) Fix leaked server connections close unused connections in the client lib close deprecated connection in the server lib The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes. --------- Co-authored-by: Maycon Santos * Add logging --------- Co-authored-by: Maycon Santos --- relay/client/client.go | 34 ++++++++------ relay/client/manager.go | 73 ++++++----------------------- relay/client/picker.go | 94 ++++++++++++++++++++++++++++++++++++++ relay/server/peer.go | 9 ++++ relay/server/store.go | 6 ++- relay/server/store_test.go | 49 +++++++++++++++++++- 6 files changed, 187 insertions(+), 78 deletions(-) create mode 100644 relay/client/picker.go diff --git a/relay/client/client.go b/relay/client/client.go index 7ff17944f..3e5c0ba24 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -142,7 +142,7 @@ type Client struct { func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { hashedID, hashedStringId := messages.HashID(peerID) return &Client{ - log: log.WithField("client_id", hashedStringId), + log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}), parentCtx: ctx, connectionURL: serverURL, authTokenStore: authTokenStore, @@ -159,7 +159,7 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. func (c *Client) Connect() error { - c.log.Infof("connecting to relay server: %s", c.connectionURL) + c.log.Infof("connecting to relay server") c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() @@ -180,7 +180,7 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) - c.log.Infof("relay connection established with: %s", c.connectionURL) + c.log.Infof("relay connection established") return nil } @@ -202,7 +202,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { return nil, ErrConnAlreadyExists } - log.Infof("open connection to peer: %s", hashedStringID) + c.log.Infof("open connection to peer: %s", hashedStringID) msgChannel := make(chan Msg, 2) conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) @@ -250,7 +250,7 @@ func (c *Client) connect() error { if err != nil { cErr := conn.Close() if cErr != nil { - log.Errorf("failed to close connection: %s", cErr) + c.log.Errorf("failed to close connection: %s", cErr) } return err } @@ -261,19 +261,19 @@ func (c *Client) connect() error { func (c *Client) handShake() error { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { - log.Errorf("failed to marshal auth message: %s", err) + c.log.Errorf("failed to marshal auth message: %s", err) return err } _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to send auth message: %s", err) + c.log.Errorf("failed to send auth message: %s", err) return err } buf := make([]byte, messages.MaxHandshakeRespSize) n, err := c.readWithTimeout(buf) if err != nil { - log.Errorf("failed to read auth response: %s", err) + c.log.Errorf("failed to read auth response: %s", err) return err } @@ -284,12 +284,12 @@ func (c *Client) handShake() error { msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) if err != nil { - log.Errorf("failed to determine message type: %s", err) + c.log.Errorf("failed to determine message type: %s", err) return err } if msgType != messages.MsgTypeAuthResponse { - log.Errorf("unexpected message type: %s", msgType) + c.log.Errorf("unexpected message type: %s", msgType) return fmt.Errorf("unexpected message type") } @@ -318,6 +318,7 @@ func (c *Client) readLoop(relayConn net.Conn) { buf := *bufPtr n, errExit = relayConn.Read(buf) if errExit != nil { + c.log.Infof("start to Relay read loop exit") c.mu.Lock() if c.serviceIsRunning && !internallyStoppedFlag.isSet() { c.log.Debugf("failed to read message from relay server: %s", errExit) @@ -364,7 +365,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, case messages.MsgTypeTransport: return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) case messages.MsgTypeClose: - log.Debugf("relay connection close by server") + c.log.Debugf("relay connection close by server") c.bufPool.Put(bufPtr) return false } @@ -433,14 +434,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [ // todo: use buffer pool instead of create new transport msg. msg, err := messages.MarshalTransportMsg(dstID, payload) if err != nil { - log.Errorf("failed to marshal transport message: %s", err) + c.log.Errorf("failed to marshal transport message: %s", err) return 0, err } // the write always return with 0 length because the underling does not support the size feedback. _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to write transport message: %s", err) + c.log.Errorf("failed to write transport message: %s", err) } return len(payload), err } @@ -459,7 +460,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in case <-c.parentCtx.Done(): err := c.close(true) if err != nil { - log.Errorf("failed to teardown connection: %s", err) + c.log.Errorf("failed to teardown connection: %s", err) } return } @@ -499,10 +500,12 @@ func (c *Client) close(gracefullyExit bool) error { var err error if !c.serviceIsRunning { c.mu.Unlock() + c.log.Warn("relay connection was already marked as not running") return nil } c.serviceIsRunning = false + c.log.Infof("closing all peer connections") c.closeAllConns() if gracefullyExit { c.writeCloseMsg() @@ -510,8 +513,9 @@ func (c *Client) close(gracefullyExit bool) error { err = c.relayConn.Close() c.mu.Unlock() + c.log.Infof("waiting for read loop to close") c.wgReadLoop.Wait() - c.log.Infof("relay connection closed with: %s", c.connectionURL) + c.log.Infof("relay connection closed") return err } diff --git a/relay/client/manager.go b/relay/client/manager.go index a9d294160..4554c7c0f 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -3,7 +3,6 @@ package client import ( "container/list" "context" - "errors" "fmt" "net" "reflect" @@ -17,8 +16,6 @@ import ( var ( relayCleanupInterval = 60 * time.Second - connectionTimeout = 30 * time.Second - maxConcurrentServers = 7 ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ) @@ -92,67 +89,23 @@ func (m *Manager) Serve() error { } log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) - totalServers := len(m.serverURLs) - - successChan := make(chan *Client, 1) - errChan := make(chan error, len(m.serverURLs)) - - ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout) - defer cancel() - - sem := make(chan struct{}, maxConcurrentServers) - - for _, url := range m.serverURLs { - sem <- struct{}{} - go func(url string) { - defer func() { <-sem }() - m.connect(m.ctx, url, successChan, errChan) - }(url) + sp := ServerPicker{ + TokenStore: m.tokenStore, + PeerID: m.peerID, } - var errCount int - - for { - select { - case client := <-successChan: - log.Infof("Successfully connected to relay server: %s", client.connectionURL) - - m.relayClient = client - - m.reconnectGuard = NewGuard(m.ctx, m.relayClient) - m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(client.connectionURL) - }) - m.startCleanupLoop() - return nil - case err := <-errChan: - errCount++ - log.Warnf("Connection attempt failed: %v", err) - if errCount == totalServers { - return errors.New("failed to connect to any relay server: all attempts failed") - } - case <-ctx.Done(): - return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err()) - } + client, err := sp.PickServer(m.ctx, m.serverURLs) + if err != nil { + return err } -} + m.relayClient = client -func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) { - // TODO: abort the connection if another connection was successful - relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID) - if err := relayClient.Connect(); err != nil { - errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err) - return - } - - select { - case successChan <- relayClient: - // This client was the first to connect successfully - default: - if err := relayClient.Close(); err != nil { - log.Debugf("failed to close relay client: %s", err) - } - } + m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnDisconnectListener(func() { + m.onServerDisconnected(client.connectionURL) + }) + m.startCleanupLoop() + return nil } // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be diff --git a/relay/client/picker.go b/relay/client/picker.go new file mode 100644 index 000000000..b0888a4a0 --- /dev/null +++ b/relay/client/picker.go @@ -0,0 +1,94 @@ +package client + +import ( + "context" + "errors" + "fmt" + "time" + + log "github.com/sirupsen/logrus" + + auth "github.com/netbirdio/netbird/relay/auth/hmac" +) + +const ( + connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 +) + +type connResult struct { + RelayClient *Client + Url string + Err error +} + +type ServerPicker struct { + TokenStore *auth.TokenStore + PeerID string +} + +func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { + ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + defer cancel() + + totalServers := len(urls) + + connResultChan := make(chan connResult, totalServers) + successChan := make(chan connResult, 1) + + concurrentLimiter := make(chan struct{}, maxConcurrentServers) + for _, url := range urls { + concurrentLimiter <- struct{}{} + go func(url string) { + defer func() { <-concurrentLimiter }() + sp.startConnection(parentCtx, connResultChan, url) + }(url) + } + + go sp.processConnResults(connResultChan, successChan) + + select { + case cr, ok := <-successChan: + if !ok { + return nil, errors.New("failed to connect to any relay server: all attempts failed") + } + log.Infof("chosen home Relay server: %s", cr.Url) + return cr.RelayClient, nil + case <-ctx.Done(): + return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err()) + } +} + +func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { + log.Infof("try to connecting to relay server: %s", url) + relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) + err := relayClient.Connect() + resultChan <- connResult{ + RelayClient: relayClient, + Url: url, + Err: err, + } +} + +func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) { + var hasSuccess bool + for cr := range resultChan { + if cr.Err != nil { + log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + continue + } + log.Infof("connected to Relay server: %s", cr.Url) + + if hasSuccess { + log.Infof("closing unnecessary Relay connection to: %s", cr.Url) + if err := cr.RelayClient.Close(); err != nil { + log.Errorf("failed to close connection to %s: %v", cr.Url, err) + } + continue + } + + hasSuccess = true + successChan <- cr + } + close(successChan) +} diff --git a/relay/server/peer.go b/relay/server/peer.go index 0de601996..00341e98b 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) { // connection. func (p *Peer) CloseGracefully(ctx context.Context) { p.connMu.Lock() + defer p.connMu.Unlock() err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()) if err != nil { p.log.Errorf("failed to send close message to peer: %s", p.String()) @@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) { if err != nil { p.log.Errorf("failed to close connection to peer: %s", err) } +} +func (p *Peer) Close() { + p.connMu.Lock() defer p.connMu.Unlock() + + if err := p.conn.Close(); err != nil { + p.log.Errorf("failed to close connection to peer: %s", err) + } } // String returns the peer ID @@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send if err != nil { p.log.Errorf("failed to close connection to peer: %s", err) } + p.log.Info("peer connection closed due healthcheck timeout") return case <-ctx.Done(): return diff --git a/relay/server/store.go b/relay/server/store.go index 96879dae1..4288e62c5 100644 --- a/relay/server/store.go +++ b/relay/server/store.go @@ -19,10 +19,14 @@ func NewStore() *Store { } // AddPeer adds a peer to the store -// todo: consider to close peer conn if the peer already exists func (s *Store) AddPeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() + odlPeer, ok := s.peers[peer.String()] + if ok { + odlPeer.Close() + } + s.peers[peer.String()] = peer } diff --git a/relay/server/store_test.go b/relay/server/store_test.go index 4a30bc131..41c7baa92 100644 --- a/relay/server/store_test.go +++ b/relay/server/store_test.go @@ -2,13 +2,57 @@ package server import ( "context" + "net" "testing" + "time" "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/relay/metrics" ) +type mockConn struct { +} + +func (m mockConn) Read(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Write(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) LocalAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) RemoteAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetReadDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetWriteDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + func TestStore_DeletePeer(t *testing.T) { s := NewStore() @@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) { m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) - p1 := NewPeer(m, []byte("peer_id"), nil, nil) - p2 := NewPeer(m, []byte("peer_id"), nil, nil) + conn := &mockConn{} + p1 := NewPeer(m, []byte("peer_id"), conn, nil) + p2 := NewPeer(m, []byte("peer_id"), conn, nil) s.AddPeer(p1) s.AddPeer(p2)