diff --git a/relay/client/client.go b/relay/client/client.go index 44325e1f9..2b722c56e 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "github.com/netbirdio/netbird/relay/client/dialer/udp" "io" "net" "sync" @@ -10,7 +11,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/relay/client/dialer/udp" "github.com/netbirdio/netbird/relay/messages" ) @@ -19,6 +19,10 @@ const ( serverResponseTimeout = 8 * time.Second ) +var ( + reconnectingTimeout = 5 * time.Second +) + type Msg struct { buf []byte } @@ -35,56 +39,51 @@ type Client struct { serverAddress string hashedID []byte - conns map[string]*connContainer // todo handle it in thread safe way + relayConnIsEstablished bool + conns map[string]*connContainer + connsMutext sync.Mutex // protect conns and relayConnIsEstablished bool - relayConn net.Conn - relayConnState bool - wgRelayConn sync.WaitGroup - mu sync.Mutex + relayConn net.Conn + serviceIsRunning bool + wgRelayConn sync.WaitGroup + mu sync.Mutex + onDisconnected chan struct{} } func NewClient(ctx context.Context, serverAddress, peerID string) *Client { ctx, ctxCancel := context.WithCancel(ctx) hashedID, hashedStringId := messages.HashID(peerID) return &Client{ - log: log.WithField("client_id", hashedStringId), - ctx: ctx, - ctxCancel: ctxCancel, - serverAddress: serverAddress, - hashedID: hashedID, - conns: make(map[string]*connContainer), + log: log.WithField("client_id", hashedStringId), + ctx: ctx, + ctxCancel: ctxCancel, + serverAddress: serverAddress, + hashedID: hashedID, + conns: make(map[string]*connContainer), + onDisconnected: make(chan struct{}), } } func (c *Client) Connect() error { c.mu.Lock() - if c.relayConnState { + if c.serviceIsRunning { c.mu.Unlock() return nil } - conn, err := udp.Dial(c.serverAddress) + err := c.connect() if err != nil { - return err - } - c.relayConn = conn - - err = c.handShake() - if err != nil { - cErr := conn.Close() - if cErr != nil { - log.Errorf("failed to close connection: %s", cErr) - } - c.relayConn = nil + c.mu.Unlock() return err } - c.relayConnState = true - c.mu.Unlock() + c.serviceIsRunning = true c.wgRelayConn.Add(1) go c.readLoop() + c.mu.Unlock() + go func() { <-c.ctx.Done() cErr := c.close() @@ -93,13 +92,50 @@ func (c *Client) Connect() error { } }() + go c.reconnectGuard() + return nil } +func (c *Client) reconnectGuard() { + for { + c.wgRelayConn.Wait() + + c.mu.Lock() + if !c.serviceIsRunning { + c.mu.Unlock() + return + } + + log.Infof("reconnecting to relay server") + err := c.connect() + if err != nil { + log.Errorf("failed to reconnect to relay server: %s", err) + c.mu.Unlock() + time.Sleep(reconnectingTimeout) + continue + } + log.Infof("reconnected to relay server") + c.wgRelayConn.Add(1) + go c.readLoop() + + c.mu.Unlock() + + } +} + func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { c.mu.Lock() defer c.mu.Unlock() - if !c.relayConnState { + + c.connsMutext.Lock() + defer c.connsMutext.Unlock() + + if !c.relayConnIsEstablished { + return nil, fmt.Errorf("relay connection is not established") + } + + if !c.serviceIsRunning { return nil, fmt.Errorf("relay connection is not established") } @@ -120,26 +156,41 @@ func (c *Client) Close() error { return c.close() } +func (c *Client) connect() error { + conn, err := udp.Dial(c.serverAddress) + if err != nil { + return err + } + c.relayConn = conn + + err = c.handShake() + if err != nil { + cErr := conn.Close() + if cErr != nil { + log.Errorf("failed to close connection: %s", cErr) + } + c.relayConn = nil + return err + } + + c.relayConnIsEstablished = true + return nil +} + func (c *Client) close() error { c.mu.Lock() defer c.mu.Unlock() - if !c.relayConnState { + if !c.serviceIsRunning { return nil } - c.relayConnState = false + c.serviceIsRunning = false err := c.relayConn.Close() c.wgRelayConn.Wait() - // close all Conn types - for _, container := range c.conns { - close(container.messages) - } - c.conns = make(map[string]*connContainer) - return err } @@ -189,17 +240,13 @@ func (c *Client) handShake() error { } func (c *Client) readLoop() { - defer func() { - c.log.Tracef("exit from read loop") - c.wgRelayConn.Done() - }() var errExit error var n int for { buf := make([]byte, bufferSize) n, errExit = c.relayConn.Read(buf) if errExit != nil { - if c.relayConnState { + if c.serviceIsRunning { c.log.Debugf("failed to read message from relay server: %s", errExit) } break @@ -232,10 +279,20 @@ func (c *Client) readLoop() { } } - if c.relayConnState { - c.log.Errorf("failed to read message from relay server: %s", errExit) + if c.serviceIsRunning { _ = c.relayConn.Close() } + + c.connsMutext.Lock() + c.relayConnIsEstablished = false + for _, container := range c.conns { + close(container.messages) + } + c.conns = make(map[string]*connContainer) + c.connsMutext.Unlock() + + c.log.Tracef("exit from read loop") + c.wgRelayConn.Done() } func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) { @@ -275,6 +332,9 @@ func (c *Client) closeConn(id string) error { c.mu.Lock() defer c.mu.Unlock() + c.connsMutext.Lock() + defer c.connsMutext.Unlock() + conn, ok := c.conns[id] if !ok { return fmt.Errorf("connection already closed") diff --git a/relay/test/client_test.go b/relay/client/client_test.go similarity index 76% rename from relay/test/client_test.go rename to relay/client/client_test.go index f70048d22..c180b922b 100644 --- a/relay/test/client_test.go +++ b/relay/client/client_test.go @@ -1,16 +1,16 @@ -package test +package client import ( "context" "net" "os" "testing" + "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/relay/server" ) @@ -39,21 +39,21 @@ func TestClient(t *testing.T) { } }() - clientAlice := client.NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientAlice.Close() - clientPlaceHolder := client.NewClient(ctx, addr, "clientPlaceHolder") + clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder") err = clientPlaceHolder.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientPlaceHolder.Close() - clientBob := client.NewClient(ctx, addr, "bob") + clientBob := NewClient(ctx, addr, "bob") err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -107,7 +107,7 @@ func TestRegistration(t *testing.T) { } }() - clientAlice := client.NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -140,7 +140,7 @@ func TestRegistrationTimeout(t *testing.T) { } defer tcpListener.Close() - clientAlice := client.NewClient(ctx, "127.0.0.1:1234", "alice") + clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice") err = clientAlice.Connect() if err == nil { t.Errorf("failed to connect to server: %s", err) @@ -173,7 +173,7 @@ func TestEcho(t *testing.T) { } }() - clientAlice := client.NewClient(ctx, addr, idAlice) + clientAlice := NewClient(ctx, addr, idAlice) err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -185,7 +185,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := client.NewClient(ctx, addr, idBob) + clientBob := NewClient(ctx, addr, idBob) err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -254,7 +254,7 @@ func TestBindToUnavailabePeer(t *testing.T) { } }() - clientAlice := client.NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -293,7 +293,7 @@ func TestBindReconnect(t *testing.T) { } }() - clientAlice := client.NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -304,7 +304,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to bind channel: %s", err) } - clientBob := client.NewClient(ctx, addr, "bob") + clientBob := NewClient(ctx, addr, "bob") err = clientBob.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -321,7 +321,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = client.NewClient(ctx, addr, "alice") + clientAlice = NewClient(ctx, addr, "alice") err = clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -375,7 +375,7 @@ func TestCloseConn(t *testing.T) { } }() - clientAlice := client.NewClient(ctx, addr, "alice") + clientAlice := NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -402,3 +402,92 @@ func TestCloseConn(t *testing.T) { t.Errorf("unexpected writing from closed connection") } } + +func TestAutoReconnect(t *testing.T) { + ctx := context.Background() + + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Errorf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv.Close() + if err != nil { + log.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := NewClient(ctx, addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + _ = clientAlice.relayConn.Close() + + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Errorf("unexpected reading from closed connection") + } + + log.Infof("waiting for reconnection") + time.Sleep(reconnectingTimeout) + + _, err = clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to open channel: %s", err) + } +} + +func TestCloseRelayConn(t *testing.T) { + ctx := context.Background() + + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Errorf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv.Close() + if err != nil { + log.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := NewClient(ctx, addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + _ = clientAlice.relayConn.Close() + + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Errorf("unexpected reading from closed connection") + } + + _, err = clientAlice.OpenConn("bob") + if err == nil { + t.Errorf("unexpected opening connection to closed server") + } +}