From 4ff069a102718ce04d6132924172c8e39eafde24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Wed, 29 May 2024 16:40:26 +0200 Subject: [PATCH] Support multiple server --- relay/client/client.go | 92 ++++++++++++++++++++++-------- relay/client/manager.go | 106 ++++++++++++++++++++--------------- relay/client/manager_test.go | 97 ++++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+), 70 deletions(-) create mode 100644 relay/client/manager_test.go diff --git a/relay/client/client.go b/relay/client/client.go index 69e9b391b..c2f6f9c71 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -3,7 +3,6 @@ package client import ( "context" "fmt" - "github.com/netbirdio/netbird/relay/client/dialer/udp" "io" "net" "sync" @@ -11,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/relay/client/dialer/udp" "github.com/netbirdio/netbird/relay/messages" ) @@ -48,6 +48,8 @@ type Client struct { serviceIsRunningMutex sync.Mutex wgReadLoop sync.WaitGroup onDisconnected chan struct{} + + remoteAddr net.Addr } func NewClient(ctx context.Context, serverAddress, peerID string) *Client { @@ -97,31 +99,35 @@ func (c *Client) Connect() error { return nil } -func (c *Client) reconnectGuard() { - for { - c.wgReadLoop.Wait() - - c.serviceIsRunningMutex.Lock() - if !c.serviceIsRunning { - c.serviceIsRunningMutex.Unlock() - return - } - - log.Infof("reconnecting to relay server") - err := c.connect() - if err != nil { - log.Errorf("failed to reconnect to relay server: %s", err) - c.serviceIsRunningMutex.Unlock() - time.Sleep(reconnectingTimeout) - continue - } - log.Infof("reconnected to relay server") - c.wgReadLoop.Add(1) - go c.readLoop() - +func (c *Client) ConnectWithoutReconnect() error { + c.serviceIsRunningMutex.Lock() + if c.serviceIsRunning { c.serviceIsRunningMutex.Unlock() - + return nil } + + err := c.connect() + if err != nil { + c.serviceIsRunningMutex.Unlock() + return err + } + + c.serviceIsRunning = true + + c.wgReadLoop.Add(1) + go c.readLoop() + + c.serviceIsRunningMutex.Unlock() + + go func() { + <-c.ctx.Done() + cErr := c.close() + if cErr != nil { + log.Errorf("failed to close relay connection: %s", cErr) + } + }() + + return nil } func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { @@ -144,6 +150,15 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { return conn, nil } +func (c *Client) RelayRemoteAddress() (net.Addr, error) { + c.serviceIsRunningMutex.Lock() + defer c.serviceIsRunningMutex.Unlock() + if c.remoteAddr == nil { + return nil, fmt.Errorf("relay connection is not established") + } + return c.remoteAddr, nil +} + func (c *Client) Close() error { c.serviceIsRunningMutex.Lock() if !c.serviceIsRunning { @@ -172,6 +187,8 @@ func (c *Client) connect() error { return err } + c.remoteAddr = conn.RemoteAddr() + c.readyToOpenConns = true return nil } @@ -193,6 +210,33 @@ func (c *Client) close() error { return err } +func (c *Client) reconnectGuard() { + for { + c.wgReadLoop.Wait() + + c.serviceIsRunningMutex.Lock() + if !c.serviceIsRunning { + c.serviceIsRunningMutex.Unlock() + return + } + + log.Infof("reconnecting to relay server") + err := c.connect() + if err != nil { + log.Errorf("failed to reconnect to relay server: %s", err) + c.serviceIsRunningMutex.Unlock() + time.Sleep(reconnectingTimeout) + continue + } + log.Infof("reconnected to relay server") + c.wgReadLoop.Add(1) + go c.readLoop() + + c.serviceIsRunningMutex.Unlock() + + } +} + func (c *Client) handShake() error { defer func() { err := c.relayConn.SetReadDeadline(time.Time{}) diff --git a/relay/client/manager.go b/relay/client/manager.go index 97793b3ea..686b4ac4f 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -2,74 +2,88 @@ package client import ( "context" + "fmt" "net" - "sync" - "time" - - log "github.com/sirupsen/logrus" ) +// Manager todo: thread safe type Manager struct { ctx context.Context srvAddress string peerID string - reconnectTime time.Duration + relayClient *Client - mu sync.Mutex - client *Client + relayClients map[string]*Client } func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { return &Manager{ - ctx: ctx, - srvAddress: serverAddress, - peerID: peerID, - reconnectTime: 5 * time.Second, + ctx: ctx, + srvAddress: serverAddress, + peerID: peerID, + relayClients: make(map[string]*Client), } } -func (m *Manager) Serve() { - ok := m.mu.TryLock() - if !ok { - return +func (m *Manager) Serve() error { + m.relayClient = NewClient(m.ctx, m.srvAddress, m.peerID) + err := m.relayClient.Connect() + if err != nil { + return err } + return nil +} - m.client = NewClient(m.ctx, m.srvAddress, m.peerID) - - go func() { - defer m.mu.Unlock() - - // todo this is not thread safe - for { - select { - case <-m.ctx.Done(): - return - default: - m.connect() - } - - select { - case <-m.ctx.Done(): - return - case <-time.After(2 * time.Second): //timeout - } - } - }() +func (m *Manager) RelayAddress() (net.Addr, error) { + if m.relayClient == nil { + return nil, fmt.Errorf("relay client not connected") + } + return m.relayClient.RelayRemoteAddress() } func (m *Manager) OpenConn(peerKey string) (net.Conn, error) { - // todo m.client nil check - return m.client.OpenConn(peerKey) + if m.relayClient == nil { + return nil, fmt.Errorf("relay client not connected") + } + + rAddr, err := m.relayClient.RelayRemoteAddress() + if err != nil { + return nil, fmt.Errorf("relay client not connected") + } + + return m.OpenConnTo(rAddr.String(), peerKey) } -// connect is blocking -func (m *Manager) connect() { - err := m.client.Connect() - if err != nil { - if m.ctx.Err() != nil { - return - } - log.Errorf("connection error with '%s': %s", m.srvAddress, err) +func (m *Manager) OpenConnTo(serverAddress, peerKey string) (net.Conn, error) { + if m.relayClient == nil { + return nil, fmt.Errorf("relay client not connected") } + rAddr, err := m.relayClient.RelayRemoteAddress() + if err != nil { + return nil, fmt.Errorf("relay client not connected") + } + + if rAddr.String() == serverAddress { + return m.relayClient.OpenConn(peerKey) + } + + relayClient, ok := m.relayClients[serverAddress] + if ok { + return relayClient.OpenConn(peerKey) + } + + relayClient = NewClient(m.ctx, serverAddress, m.peerID) + err = relayClient.ConnectWithoutReconnect() + if err != nil { + return nil, err + } + + conn, err := relayClient.OpenConn(peerKey) + if err != nil { + return nil, err + } + + m.relayClients[serverAddress] = relayClient + return conn, nil } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go new file mode 100644 index 000000000..6c8edd534 --- /dev/null +++ b/relay/client/manager_test.go @@ -0,0 +1,97 @@ +package client + +import ( + "context" + "testing" + + "github.com/netbirdio/netbird/relay/server" +) + +func TestNewManager(t *testing.T) { + ctx := context.Background() + idAlice := "alice" + idBob := "bob" + addr1 := "localhost:1234" + srv1 := server.NewServer() + go func() { + err := srv1.Listen(addr1) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv1.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + addr2 := "localhost:2234" + srv2 := server.NewServer() + go func() { + err := srv2.Listen(addr2) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv2.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := NewManager(ctx, addr1, idAlice) + err := clientAlice.Serve() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + clientBob := NewManager(ctx, addr2, idBob) + err = clientBob.Serve() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + bobsSrvAddr, err := clientBob.RelayAddress() + if err != nil { + t.Fatalf("failed to get relay address: %s", err) + } + connAliceToBob, err := clientAlice.OpenConnTo(bobsSrvAddr.String(), idBob) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + connBobToAlice, err := clientBob.OpenConn(idAlice) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + buf := make([]byte, 65535) + n, err := connBobToAlice.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + _, err = connBobToAlice.Write(buf[:n]) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + n, err = connAliceToBob.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + if payload != string(buf[:n]) { + t.Fatalf("expected %s, got %s", payload, string(buf[:n])) + } +}