Fix leaked server connections (#2596)

Fix leaked server connections

close unused connections in the client lib
close deprecated connection in the server lib
The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes.

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Add logging

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
This commit is contained in:
Zoltan Papp 2024-09-16 16:11:10 +02:00 committed by GitHub
parent 6c50b0c84b
commit 97e10e440c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 187 additions and 78 deletions

View File

@ -142,7 +142,7 @@ type Client struct {
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID, hashedStringId := messages.HashID(peerID)
return &Client{ return &Client{
log: log.WithField("client_id", hashedStringId), log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}),
parentCtx: ctx, parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authTokenStore: authTokenStore, authTokenStore: authTokenStore,
@ -159,7 +159,7 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error { func (c *Client) Connect() error {
c.log.Infof("connecting to relay server: %s", c.connectionURL) c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
@ -180,7 +180,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(c.relayConn)
c.log.Infof("relay connection established with: %s", c.connectionURL) c.log.Infof("relay connection established")
return nil return nil
} }
@ -202,7 +202,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
return nil, ErrConnAlreadyExists return nil, ErrConnAlreadyExists
} }
log.Infof("open connection to peer: %s", hashedStringID) c.log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2) msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
@ -250,7 +250,7 @@ func (c *Client) connect() error {
if err != nil { if err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
} }
return err return err
} }
@ -261,19 +261,19 @@ func (c *Client) connect() error {
func (c *Client) handShake() error { func (c *Client) handShake() error {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil { if err != nil {
log.Errorf("failed to marshal auth message: %s", err) c.log.Errorf("failed to marshal auth message: %s", err)
return err return err
} }
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to send auth message: %s", err) c.log.Errorf("failed to send auth message: %s", err)
return err return err
} }
buf := make([]byte, messages.MaxHandshakeRespSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(buf)
if err != nil { if err != nil {
log.Errorf("failed to read auth response: %s", err) c.log.Errorf("failed to read auth response: %s", err)
return err return err
} }
@ -284,12 +284,12 @@ func (c *Client) handShake() error {
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil { if err != nil {
log.Errorf("failed to determine message type: %s", err) c.log.Errorf("failed to determine message type: %s", err)
return err return err
} }
if msgType != messages.MsgTypeAuthResponse { if msgType != messages.MsgTypeAuthResponse {
log.Errorf("unexpected message type: %s", msgType) c.log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type") return fmt.Errorf("unexpected message type")
} }
@ -318,6 +318,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
buf := *bufPtr buf := *bufPtr
n, errExit = relayConn.Read(buf) n, errExit = relayConn.Read(buf)
if errExit != nil { if errExit != nil {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock() c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() { if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit) c.log.Debugf("failed to read message from relay server: %s", errExit)
@ -364,7 +365,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose: case messages.MsgTypeClose:
log.Debugf("relay connection close by server") c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return false return false
} }
@ -433,14 +434,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
// todo: use buffer pool instead of create new transport msg. // todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload) msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil { if err != nil {
log.Errorf("failed to marshal transport message: %s", err) c.log.Errorf("failed to marshal transport message: %s", err)
return 0, err return 0, err
} }
// the write always return with 0 length because the underling does not support the size feedback. // the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg) _, err = c.relayConn.Write(msg)
if err != nil { if err != nil {
log.Errorf("failed to write transport message: %s", err) c.log.Errorf("failed to write transport message: %s", err)
} }
return len(payload), err return len(payload), err
} }
@ -459,7 +460,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
case <-c.parentCtx.Done(): case <-c.parentCtx.Done():
err := c.close(true) err := c.close(true)
if err != nil { if err != nil {
log.Errorf("failed to teardown connection: %s", err) c.log.Errorf("failed to teardown connection: %s", err)
} }
return return
} }
@ -499,10 +500,12 @@ func (c *Client) close(gracefullyExit bool) error {
var err error var err error
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock() c.mu.Unlock()
c.log.Warn("relay connection was already marked as not running")
return nil return nil
} }
c.serviceIsRunning = false c.serviceIsRunning = false
c.log.Infof("closing all peer connections")
c.closeAllConns() c.closeAllConns()
if gracefullyExit { if gracefullyExit {
c.writeCloseMsg() c.writeCloseMsg()
@ -510,8 +513,9 @@ func (c *Client) close(gracefullyExit bool) error {
err = c.relayConn.Close() err = c.relayConn.Close()
c.mu.Unlock() c.mu.Unlock()
c.log.Infof("waiting for read loop to close")
c.wgReadLoop.Wait() c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.connectionURL) c.log.Infof("relay connection closed")
return err return err
} }

View File

@ -3,7 +3,6 @@ package client
import ( import (
"container/list" "container/list"
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"reflect" "reflect"
@ -17,8 +16,6 @@ import (
var ( var (
relayCleanupInterval = 60 * time.Second relayCleanupInterval = 60 * time.Second
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
) )
@ -92,31 +89,15 @@ func (m *Manager) Serve() error {
} }
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
totalServers := len(m.serverURLs) sp := ServerPicker{
TokenStore: m.tokenStore,
successChan := make(chan *Client, 1) PeerID: m.peerID,
errChan := make(chan error, len(m.serverURLs))
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
defer cancel()
sem := make(chan struct{}, maxConcurrentServers)
for _, url := range m.serverURLs {
sem <- struct{}{}
go func(url string) {
defer func() { <-sem }()
m.connect(m.ctx, url, successChan, errChan)
}(url)
} }
var errCount int client, err := sp.PickServer(m.ctx, m.serverURLs)
if err != nil {
for { return err
select { }
case client := <-successChan:
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
m.relayClient = client m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient) m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
@ -125,34 +106,6 @@ func (m *Manager) Serve() error {
}) })
m.startCleanupLoop() m.startCleanupLoop()
return nil return nil
case err := <-errChan:
errCount++
log.Warnf("Connection attempt failed: %v", err)
if errCount == totalServers {
return errors.New("failed to connect to any relay server: all attempts failed")
}
case <-ctx.Done():
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
}
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
// TODO: abort the connection if another connection was successful
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
if err := relayClient.Connect(); err != nil {
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
return
}
select {
case successChan <- relayClient:
// This client was the first to connect successfully
default:
if err := relayClient.Close(); err != nil {
log.Debugf("failed to close relay client: %s", err)
}
}
} }
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be

94
relay/client/picker.go Normal file
View File

@ -0,0 +1,94 @@
package client
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
const (
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
type connResult struct {
RelayClient *Client
Url string
Err error
}
type ServerPicker struct {
TokenStore *auth.TokenStore
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(urls)
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls {
concurrentLimiter <- struct{}{}
go func(url string) {
defer func() { <-concurrentLimiter }()
sp.startConnection(parentCtx, connResultChan, url)
}(url)
}
go sp.processConnResults(connResultChan, successChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect()
resultChan <- connResult{
RelayClient: relayClient,
Url: url,
Err: err,
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool
for cr := range resultChan {
if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
if hasSuccess {
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
if err := cr.RelayClient.Close(); err != nil {
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
}
continue
}
hasSuccess = true
successChan <- cr
}
close(successChan)
}

View File

@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
// connection. // connection.
func (p *Peer) CloseGracefully(ctx context.Context) { func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock() p.connMu.Lock()
defer p.connMu.Unlock()
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()) err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil { if err != nil {
p.log.Errorf("failed to send close message to peer: %s", p.String()) p.log.Errorf("failed to send close message to peer: %s", p.String())
@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
if err != nil { if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err) p.log.Errorf("failed to close connection to peer: %s", err)
} }
}
func (p *Peer) Close() {
p.connMu.Lock()
defer p.connMu.Unlock() defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
} }
// String returns the peer ID // String returns the peer ID
@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
if err != nil { if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err) p.log.Errorf("failed to close connection to peer: %s", err)
} }
p.log.Info("peer connection closed due healthcheck timeout")
return return
case <-ctx.Done(): case <-ctx.Done():
return return

View File

@ -19,10 +19,14 @@ func NewStore() *Store {
} }
// AddPeer adds a peer to the store // AddPeer adds a peer to the store
// todo: consider to close peer conn if the peer already exists
func (s *Store) AddPeer(peer *Peer) { func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()]
if ok {
odlPeer.Close()
}
s.peers[peer.String()] = peer s.peers[peer.String()] = peer
} }

View File

@ -2,13 +2,57 @@ package server
import ( import (
"context" "context"
"net"
"testing" "testing"
"time"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
) )
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) { func TestStore_DeletePeer(t *testing.T) {
s := NewStore() s := NewStore()
@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p1 := NewPeer(m, []byte("peer_id"), nil, nil) conn := &mockConn{}
p2 := NewPeer(m, []byte("peer_id"), nil, nil) p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1) s.AddPeer(p1)
s.AddPeer(p2) s.AddPeer(p2)