mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-24 17:13:30 +01:00
Fix leaked server connections (#2596)
Fix leaked server connections close unused connections in the client lib close deprecated connection in the server lib The Server Picker is reusable in the guard if we want in the future. So we can support the server address changes. --------- Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Add logging --------- Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
This commit is contained in:
parent
6c50b0c84b
commit
97e10e440c
@ -142,7 +142,7 @@ type Client struct {
|
||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
return &Client{
|
||||
log: log.WithField("client_id", hashedStringId),
|
||||
log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}),
|
||||
parentCtx: ctx,
|
||||
connectionURL: serverURL,
|
||||
authTokenStore: authTokenStore,
|
||||
@ -159,7 +159,7 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect() error {
|
||||
c.log.Infof("connecting to relay server: %s", c.connectionURL)
|
||||
c.log.Infof("connecting to relay server")
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
@ -180,7 +180,7 @@ func (c *Client) Connect() error {
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(c.relayConn)
|
||||
|
||||
c.log.Infof("relay connection established with: %s", c.connectionURL)
|
||||
c.log.Infof("relay connection established")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -202,7 +202,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
|
||||
log.Infof("open connection to peer: %s", hashedStringID)
|
||||
c.log.Infof("open connection to peer: %s", hashedStringID)
|
||||
msgChannel := make(chan Msg, 2)
|
||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
||||
|
||||
@ -250,7 +250,7 @@ func (c *Client) connect() error {
|
||||
if err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection: %s", cErr)
|
||||
c.log.Errorf("failed to close connection: %s", cErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -261,19 +261,19 @@ func (c *Client) connect() error {
|
||||
func (c *Client) handShake() error {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal auth message: %s", err)
|
||||
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send auth message: %s", err)
|
||||
c.log.Errorf("failed to send auth message: %s", err)
|
||||
return err
|
||||
}
|
||||
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||
n, err := c.readWithTimeout(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read auth response: %s", err)
|
||||
c.log.Errorf("failed to read auth response: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -284,12 +284,12 @@ func (c *Client) handShake() error {
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeAuthResponse {
|
||||
log.Errorf("unexpected message type: %s", msgType)
|
||||
c.log.Errorf("unexpected message type: %s", msgType)
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
@ -318,6 +318,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.log.Infof("start to Relay read loop exit")
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
@ -364,7 +365,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
|
||||
case messages.MsgTypeTransport:
|
||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||
case messages.MsgTypeClose:
|
||||
log.Debugf("relay connection close by server")
|
||||
c.log.Debugf("relay connection close by server")
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
@ -433,14 +434,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal transport message: %s", err)
|
||||
c.log.Errorf("failed to marshal transport message: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write transport message: %s", err)
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
@ -459,7 +460,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
|
||||
case <-c.parentCtx.Done():
|
||||
err := c.close(true)
|
||||
if err != nil {
|
||||
log.Errorf("failed to teardown connection: %s", err)
|
||||
c.log.Errorf("failed to teardown connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -499,10 +500,12 @@ func (c *Client) close(gracefullyExit bool) error {
|
||||
var err error
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
c.log.Warn("relay connection was already marked as not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
c.serviceIsRunning = false
|
||||
c.log.Infof("closing all peer connections")
|
||||
c.closeAllConns()
|
||||
if gracefullyExit {
|
||||
c.writeCloseMsg()
|
||||
@ -510,8 +513,9 @@ func (c *Client) close(gracefullyExit bool) error {
|
||||
err = c.relayConn.Close()
|
||||
c.mu.Unlock()
|
||||
|
||||
c.log.Infof("waiting for read loop to close")
|
||||
c.wgReadLoop.Wait()
|
||||
c.log.Infof("relay connection closed with: %s", c.connectionURL)
|
||||
c.log.Infof("relay connection closed")
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,6 @@ package client
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
@ -17,8 +16,6 @@ import (
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
@ -92,31 +89,15 @@ func (m *Manager) Serve() error {
|
||||
}
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
||||
|
||||
totalServers := len(m.serverURLs)
|
||||
|
||||
successChan := make(chan *Client, 1)
|
||||
errChan := make(chan error, len(m.serverURLs))
|
||||
|
||||
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
for _, url := range m.serverURLs {
|
||||
sem <- struct{}{}
|
||||
go func(url string) {
|
||||
defer func() { <-sem }()
|
||||
m.connect(m.ctx, url, successChan, errChan)
|
||||
}(url)
|
||||
sp := ServerPicker{
|
||||
TokenStore: m.tokenStore,
|
||||
PeerID: m.peerID,
|
||||
}
|
||||
|
||||
var errCount int
|
||||
|
||||
for {
|
||||
select {
|
||||
case client := <-successChan:
|
||||
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
|
||||
|
||||
client, err := sp.PickServer(m.ctx, m.serverURLs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.relayClient = client
|
||||
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
@ -125,34 +106,6 @@ func (m *Manager) Serve() error {
|
||||
})
|
||||
m.startCleanupLoop()
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
errCount++
|
||||
log.Warnf("Connection attempt failed: %v", err)
|
||||
if errCount == totalServers {
|
||||
return errors.New("failed to connect to any relay server: all attempts failed")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
|
||||
// TODO: abort the connection if another connection was successful
|
||||
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
|
||||
if err := relayClient.Connect(); err != nil {
|
||||
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case successChan <- relayClient:
|
||||
// This client was the first to connect successfully
|
||||
default:
|
||||
if err := relayClient.Close(); err != nil {
|
||||
log.Debugf("failed to close relay client: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
|
94
relay/client/picker.go
Normal file
94
relay/client/picker.go
Normal file
@ -0,0 +1,94 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
)
|
||||
|
||||
const (
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
)
|
||||
|
||||
type connResult struct {
|
||||
RelayClient *Client
|
||||
Url string
|
||||
Err error
|
||||
}
|
||||
|
||||
type ServerPicker struct {
|
||||
TokenStore *auth.TokenStore
|
||||
PeerID string
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
totalServers := len(urls)
|
||||
|
||||
connResultChan := make(chan connResult, totalServers)
|
||||
successChan := make(chan connResult, 1)
|
||||
|
||||
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||
for _, url := range urls {
|
||||
concurrentLimiter <- struct{}{}
|
||||
go func(url string) {
|
||||
defer func() { <-concurrentLimiter }()
|
||||
sp.startConnection(parentCtx, connResultChan, url)
|
||||
}(url)
|
||||
}
|
||||
|
||||
go sp.processConnResults(connResultChan, successChan)
|
||||
|
||||
select {
|
||||
case cr, ok := <-successChan:
|
||||
if !ok {
|
||||
return nil, errors.New("failed to connect to any relay server: all attempts failed")
|
||||
}
|
||||
log.Infof("chosen home Relay server: %s", cr.Url)
|
||||
return cr.RelayClient, nil
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
|
||||
err := relayClient.Connect()
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
Url: url,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
|
||||
var hasSuccess bool
|
||||
for cr := range resultChan {
|
||||
if cr.Err != nil {
|
||||
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
continue
|
||||
}
|
||||
log.Infof("connected to Relay server: %s", cr.Url)
|
||||
|
||||
if hasSuccess {
|
||||
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
|
||||
if err := cr.RelayClient.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
hasSuccess = true
|
||||
successChan <- cr
|
||||
}
|
||||
close(successChan)
|
||||
}
|
@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) {
|
||||
// connection.
|
||||
func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.connMu.Lock()
|
||||
defer p.connMu.Unlock()
|
||||
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) Close() {
|
||||
p.connMu.Lock()
|
||||
defer p.connMu.Unlock()
|
||||
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the peer ID
|
||||
@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
}
|
||||
p.log.Info("peer connection closed due healthcheck timeout")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
@ -19,10 +19,14 @@ func NewStore() *Store {
|
||||
}
|
||||
|
||||
// AddPeer adds a peer to the store
|
||||
// todo: consider to close peer conn if the peer already exists
|
||||
func (s *Store) AddPeer(peer *Peer) {
|
||||
s.peersLock.Lock()
|
||||
defer s.peersLock.Unlock()
|
||||
odlPeer, ok := s.peers[peer.String()]
|
||||
if ok {
|
||||
odlPeer.Close()
|
||||
}
|
||||
|
||||
s.peers[peer.String()] = peer
|
||||
}
|
||||
|
||||
|
@ -2,13 +2,57 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
}
|
||||
|
||||
func (m mockConn) Read(b []byte) (n int, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) Write(b []byte) (n int, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConn) LocalAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) RemoteAddr() net.Addr {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m mockConn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestStore_DeletePeer(t *testing.T) {
|
||||
s := NewStore()
|
||||
|
||||
@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
||||
|
||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
||||
|
||||
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
|
||||
conn := &mockConn{}
|
||||
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
|
||||
|
||||
s.AddPeer(p1)
|
||||
s.AddPeer(p2)
|
||||
|
Loading…
Reference in New Issue
Block a user