diff --git a/relay/client/client.go b/relay/client/client.go index 7ff17944f..3e5c0ba24 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -142,7 +142,7 @@ type Client struct { func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { hashedID, hashedStringId := messages.HashID(peerID) return &Client{ - log: log.WithField("client_id", hashedStringId), + log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}), parentCtx: ctx, connectionURL: serverURL, authTokenStore: authTokenStore, @@ -159,7 +159,7 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. func (c *Client) Connect() error { - c.log.Infof("connecting to relay server: %s", c.connectionURL) + c.log.Infof("connecting to relay server") c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() @@ -180,7 +180,7 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) - c.log.Infof("relay connection established with: %s", c.connectionURL) + c.log.Infof("relay connection established") return nil } @@ -202,7 +202,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { return nil, ErrConnAlreadyExists } - log.Infof("open connection to peer: %s", hashedStringID) + c.log.Infof("open connection to peer: %s", hashedStringID) msgChannel := make(chan Msg, 2) conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) @@ -250,7 +250,7 @@ func (c *Client) connect() error { if err != nil { cErr := conn.Close() if cErr != nil { - log.Errorf("failed to close connection: %s", cErr) + c.log.Errorf("failed to close connection: %s", cErr) } return err } @@ -261,19 +261,19 @@ func (c *Client) connect() error { func (c *Client) handShake() error { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != 
nil { - log.Errorf("failed to marshal auth message: %s", err) + c.log.Errorf("failed to marshal auth message: %s", err) return err } _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to send auth message: %s", err) + c.log.Errorf("failed to send auth message: %s", err) return err } buf := make([]byte, messages.MaxHandshakeRespSize) n, err := c.readWithTimeout(buf) if err != nil { - log.Errorf("failed to read auth response: %s", err) + c.log.Errorf("failed to read auth response: %s", err) return err } @@ -284,12 +284,12 @@ func (c *Client) handShake() error { msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) if err != nil { - log.Errorf("failed to determine message type: %s", err) + c.log.Errorf("failed to determine message type: %s", err) return err } if msgType != messages.MsgTypeAuthResponse { - log.Errorf("unexpected message type: %s", msgType) + c.log.Errorf("unexpected message type: %s", msgType) return fmt.Errorf("unexpected message type") } @@ -318,6 +318,7 @@ func (c *Client) readLoop(relayConn net.Conn) { buf := *bufPtr n, errExit = relayConn.Read(buf) if errExit != nil { + c.log.Infof("start to Relay read loop exit") c.mu.Lock() if c.serviceIsRunning && !internallyStoppedFlag.isSet() { c.log.Debugf("failed to read message from relay server: %s", errExit) @@ -364,7 +365,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, case messages.MsgTypeTransport: return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) case messages.MsgTypeClose: - log.Debugf("relay connection close by server") + c.log.Debugf("relay connection close by server") c.bufPool.Put(bufPtr) return false } @@ -433,14 +434,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [ // todo: use buffer pool instead of create new transport msg. 
msg, err := messages.MarshalTransportMsg(dstID, payload) if err != nil { - log.Errorf("failed to marshal transport message: %s", err) + c.log.Errorf("failed to marshal transport message: %s", err) return 0, err } // the write always return with 0 length because the underling does not support the size feedback. _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to write transport message: %s", err) + c.log.Errorf("failed to write transport message: %s", err) } return len(payload), err } @@ -459,7 +460,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in case <-c.parentCtx.Done(): err := c.close(true) if err != nil { - log.Errorf("failed to teardown connection: %s", err) + c.log.Errorf("failed to teardown connection: %s", err) } return } @@ -499,10 +500,12 @@ func (c *Client) close(gracefullyExit bool) error { var err error if !c.serviceIsRunning { c.mu.Unlock() + c.log.Warn("relay connection was already marked as not running") return nil } c.serviceIsRunning = false + c.log.Infof("closing all peer connections") c.closeAllConns() if gracefullyExit { c.writeCloseMsg() @@ -510,8 +513,9 @@ func (c *Client) close(gracefullyExit bool) error { err = c.relayConn.Close() c.mu.Unlock() + c.log.Infof("waiting for read loop to close") c.wgReadLoop.Wait() - c.log.Infof("relay connection closed with: %s", c.connectionURL) + c.log.Infof("relay connection closed") return err } diff --git a/relay/client/manager.go b/relay/client/manager.go index a9d294160..4554c7c0f 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -3,7 +3,6 @@ package client import ( "container/list" "context" - "errors" "fmt" "net" "reflect" @@ -17,8 +16,6 @@ import ( var ( relayCleanupInterval = 60 * time.Second - connectionTimeout = 30 * time.Second - maxConcurrentServers = 7 ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ) @@ -92,67 +89,23 @@ func (m *Manager) Serve() error { } log.Debugf("starting relay client manager 
with %v relay servers", m.serverURLs) - totalServers := len(m.serverURLs) - - successChan := make(chan *Client, 1) - errChan := make(chan error, len(m.serverURLs)) - - ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout) - defer cancel() - - sem := make(chan struct{}, maxConcurrentServers) - - for _, url := range m.serverURLs { - sem <- struct{}{} - go func(url string) { - defer func() { <-sem }() - m.connect(m.ctx, url, successChan, errChan) - }(url) + sp := ServerPicker{ + TokenStore: m.tokenStore, + PeerID: m.peerID, } - var errCount int - - for { - select { - case client := <-successChan: - log.Infof("Successfully connected to relay server: %s", client.connectionURL) - - m.relayClient = client - - m.reconnectGuard = NewGuard(m.ctx, m.relayClient) - m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(client.connectionURL) - }) - m.startCleanupLoop() - return nil - case err := <-errChan: - errCount++ - log.Warnf("Connection attempt failed: %v", err) - if errCount == totalServers { - return errors.New("failed to connect to any relay server: all attempts failed") - } - case <-ctx.Done(): - return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err()) - } + client, err := sp.PickServer(m.ctx, m.serverURLs) + if err != nil { + return err } -} + m.relayClient = client -func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) { - // TODO: abort the connection if another connection was successful - relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID) - if err := relayClient.Connect(); err != nil { - errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err) - return - } - - select { - case successChan <- relayClient: - // This client was the first to connect successfully - default: - if err := relayClient.Close(); err != nil { - log.Debugf("failed to close relay client: %s", err) - } - } + m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + 
m.relayClient.SetOnDisconnectListener(func() { + m.onServerDisconnected(client.connectionURL) + }) + m.startCleanupLoop() + return nil } // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be diff --git a/relay/client/picker.go b/relay/client/picker.go new file mode 100644 index 000000000..b0888a4a0 --- /dev/null +++ b/relay/client/picker.go @@ -0,0 +1,94 @@ +package client + +import ( + "context" + "errors" + "fmt" + "time" + + log "github.com/sirupsen/logrus" + + auth "github.com/netbirdio/netbird/relay/auth/hmac" +) + +const ( + connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 +) + +type connResult struct { + RelayClient *Client + Url string + Err error +} + +type ServerPicker struct { + TokenStore *auth.TokenStore + PeerID string +} + +func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { + ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + defer cancel() + + totalServers := len(urls) + + connResultChan := make(chan connResult, totalServers) + successChan := make(chan connResult, 1) + + concurrentLimiter := make(chan struct{}, maxConcurrentServers) + for _, url := range urls { + concurrentLimiter <- struct{}{} + go func(url string) { + defer func() { <-concurrentLimiter }() + sp.startConnection(parentCtx, connResultChan, url) + }(url) + } + + go sp.processConnResults(connResultChan, successChan) + + select { + case cr, ok := <-successChan: + if !ok { + return nil, errors.New("failed to connect to any relay server: all attempts failed") + } + log.Infof("chosen home Relay server: %s", cr.Url) + return cr.RelayClient, nil + case <-ctx.Done(): + return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err()) + } +} + +func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { + log.Infof("try to connecting to relay server: %s", url) + relayClient := NewClient(ctx, url, 
sp.TokenStore, sp.PeerID) + err := relayClient.Connect() + resultChan <- connResult{ + RelayClient: relayClient, + Url: url, + Err: err, + } +} + +func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) { + var hasSuccess bool + // resultChan is never closed: read exactly cap(resultChan) results, one per connection attempt + for i := 0; i < cap(resultChan); i++ { + cr := <-resultChan + if cr.Err != nil { + log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + continue + } + log.Infof("connected to Relay server: %s", cr.Url) + if hasSuccess { + log.Infof("closing unnecessary Relay connection to: %s", cr.Url) + if err := cr.RelayClient.Close(); err != nil { + log.Errorf("failed to close connection to %s: %v", cr.Url, err) + } + continue + } + hasSuccess = true + successChan <- cr + } + close(successChan) +} diff --git a/relay/server/peer.go b/relay/server/peer.go index 0de601996..00341e98b 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) { // connection. func (p *Peer) CloseGracefully(ctx context.Context) { p.connMu.Lock() + defer p.connMu.Unlock() err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()) if err != nil { p.log.Errorf("failed to send close message to peer: %s", p.String()) @@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) { if err != nil { p.log.Errorf("failed to close connection to peer: %s", err) } +} +func (p *Peer) Close() { + p.connMu.Lock() defer p.connMu.Unlock() + + if err := p.conn.Close(); err != nil { + p.log.Errorf("failed to close connection to peer: %s", err) + } } // String returns the peer ID @@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send if err != nil { p.log.Errorf("failed to close connection to peer: %s", err) } + p.log.Info("peer connection closed due healthcheck timeout") return case <-ctx.Done(): return diff --git a/relay/server/store.go b/relay/server/store.go index 96879dae1..4288e62c5 100644 --- a/relay/server/store.go +++ 
b/relay/server/store.go @@ -19,10 +19,15 @@ func NewStore() *Store { } // AddPeer adds a peer to the store -// todo: consider to close peer conn if the peer already exists func (s *Store) AddPeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() + // close the peer already stored under the same ID before it is replaced + oldPeer, ok := s.peers[peer.String()] + if ok { + oldPeer.Close() + } + s.peers[peer.String()] = peer } diff --git a/relay/server/store_test.go b/relay/server/store_test.go index 4a30bc131..41c7baa92 100644 --- a/relay/server/store_test.go +++ b/relay/server/store_test.go @@ -2,13 +2,57 @@ package server import ( "context" + "net" "testing" + "time" "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/relay/metrics" ) +type mockConn struct { +} + +func (m mockConn) Read(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Write(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) LocalAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) RemoteAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetReadDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetWriteDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + func TestStore_DeletePeer(t *testing.T) { s := NewStore() @@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) { m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) - p1 := NewPeer(m, []byte("peer_id"), nil, nil) - p2 := NewPeer(m, []byte("peer_id"), nil, nil) + conn := &mockConn{} + p1 := NewPeer(m, []byte("peer_id"), conn, nil) + p2 := NewPeer(m, []byte("peer_id"), conn, nil) s.AddPeer(p1) s.AddPeer(p2)