diff --git a/relay/client/manager.go b/relay/client/manager.go index e2ced6f5b..f8d5a28fc 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -5,8 +5,19 @@ import ( "fmt" "net" "sync" + + log "github.com/sirupsen/logrus" ) +type RelayTrack struct { + sync.RWMutex + relayClient *Client +} + +func NewRelayTrack() *RelayTrack { + return &RelayTrack{} +} + type Manager struct { ctx context.Context srvAddress string @@ -15,8 +26,8 @@ type Manager struct { relayClient *Client reconnectGuard *Guard - relayClients map[string]*Client - relayClientsMutex sync.Mutex + relayClients map[string]*RelayTrack + relayClientsMutex sync.RWMutex } func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { @@ -24,7 +35,7 @@ func NewManager(ctx context.Context, serverAddress string, peerID string) *Manag ctx: ctx, srvAddress: serverAddress, peerID: peerID, - relayClients: make(map[string]*Client), + relayClients: make(map[string]*RelayTrack), } } @@ -50,10 +61,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { return nil, err } - if foreign { - return m.openConnVia(serverAddress, peerKey) - } else { + if !foreign { return m.relayClient.OpenConn(peerKey) + } else { + return m.openConnVia(serverAddress, peerKey) } } @@ -65,30 +76,50 @@ func (m *Manager) RelayAddress() (net.Addr, error) { } func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { - relayClient, ok := m.relayClients[serverAddress] + m.relayClientsMutex.RLock() + relayTrack, ok := m.relayClients[serverAddress] if ok { - return relayClient.OpenConn(peerKey) + relayTrack.RLock() + m.relayClientsMutex.RUnlock() + defer relayTrack.RUnlock() + return relayTrack.relayClient.OpenConn(peerKey) } + m.relayClientsMutex.RUnlock() - relayClient = NewClient(m.ctx, serverAddress, m.peerID) + rt := NewRelayTrack() + rt.Lock() + + m.relayClientsMutex.Lock() + m.relayClients[serverAddress] = rt + m.relayClientsMutex.Unlock() + + relayClient := NewClient(m.ctx, serverAddress, m.peerID) err := relayClient.Connect() if err != nil { + rt.Unlock() + m.relayClientsMutex.Lock() + delete(m.relayClients, serverAddress) + m.relayClientsMutex.Unlock() return nil, err } relayClient.SetOnDisconnectListener(func() { m.deleteRelayConn(serverAddress) }) + rt.Unlock() + conn, err := relayClient.OpenConn(peerKey) if err != nil { return nil, err } - m.relayClients[serverAddress] = relayClient - return conn, nil } func (m *Manager) deleteRelayConn(address string) { + log.Infof("deleting relay client for %s", address) + m.relayClientsMutex.Lock() + defer m.relayClientsMutex.Unlock() + delete(m.relayClients, address) } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 2c041fff4..0fbdddd7e 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -4,13 +4,14 @@ import ( "context" "testing" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/relay/server" ) -func TestNewManager(t *testing.T) { +func TestForeignConn(t *testing.T) { ctx := context.Background() - idAlice := "alice" - idBob := "bob" + addr1 := "localhost:1234" srv1 := server.NewServer() go func() { @@ -43,22 +44,21 @@ func TestNewManager(t *testing.T) { } }() + idAlice := "alice" + log.Debugf("connect by alice") clientAlice := NewManager(ctx, addr1, idAlice) err := clientAlice.Serve() if err != nil { t.Fatalf("failed to connect to server: %s", err) } - aliceSrvAddr, err := clientAlice.RelayAddress() - if err != nil { - t.Fatalf("failed to get relay address: %s", err) - } + idBob := "bob" + log.Debugf("connect by bob") clientBob := NewManager(ctx, addr2, idBob) err = clientBob.Serve() if err != nil { t.Fatalf("failed to connect to server: %s", err) } - bobsSrvAddr, err := clientBob.RelayAddress() if err != nil { t.Fatalf("failed to get relay address: %s", err) @@ -67,8 +67,7 @@ func TestNewManager(t *testing.T) { if err != nil { t.Fatalf("failed to bind channel: %s", err) } - - connBobToAlice, err := clientBob.OpenConn(aliceSrvAddr.String(), idAlice) + connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -99,3 +98,59 @@ func TestNewManager(t *testing.T) { t.Fatalf("expected %s, got %s", payload, string(buf[:n])) } } + +func TestForeginConnClose(t *testing.T) { + ctx := context.Background() + + addr1 := "localhost:1234" + srv1 := server.NewServer() + go func() { + err := srv1.Listen(addr1) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv1.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + addr2 := "localhost:2234" + srv2 := server.NewServer() + go func() { + err := srv2.Listen(addr2) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv2.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + idAlice := "alice" + log.Debugf("connect by alice") + clientAlice := NewManager(ctx, addr1, idAlice) + err := clientAlice.Serve() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn(addr2, "anotherpeer") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + err = conn.Close() + if err != nil { + t.Fatalf("failed to close connection: %s", err) + } + + select {} +}