diff --git a/relay/client/client.go b/relay/client/client.go index d574e04a1..e7492afee 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "io" "net" @@ -29,20 +30,25 @@ type connContainer struct { type Client struct { log *log.Entry + ctx context.Context + ctxCancel context.CancelFunc serverAddress string hashedID []byte - conns map[string]*connContainer + conns map[string]*connContainer // todo handle it in thread safe way relayConn net.Conn relayConnState bool mu sync.Mutex } -func NewClient(serverAddress, peerID string) *Client { +func NewClient(ctx context.Context, serverAddress, peerID string) *Client { + ctx, ctxCancel := context.WithCancel(ctx) hashedID, hashedStringId := messages.HashID(peerID) return &Client{ log: log.WithField("client_id", hashedStringId), + ctx: ctx, + ctxCancel: ctxCancel, serverAddress: serverAddress, hashedID: hashedID, conns: make(map[string]*connContainer), @@ -51,7 +57,11 @@ func NewClient(serverAddress, peerID string) *Client { func (c *Client) Connect() error { c.mu.Lock() - defer c.mu.Unlock() + if c.relayConnState { + c.mu.Unlock() + return nil + } + conn, err := udp.Dial(c.serverAddress) if err != nil { return err @@ -68,18 +78,39 @@ func (c *Client) Connect() error { return err } - err = c.relayConn.SetReadDeadline(time.Time{}) - if err != nil { - log.Errorf("failed to reset read deadline: %s", err) - return err - } - c.relayConnState = true - go c.readLoop() + c.mu.Unlock() + + go func() { + <-c.ctx.Done() + cErr := c.close() + if cErr != nil { + log.Errorf("failed to close relay connection: %s", cErr) + } + }() + // blocking function + c.readLoop() + + c.mu.Lock() + + // close all Conn types + for _, container := range c.conns { + close(container.messages) + } + c.conns = make(map[string]*connContainer) + + c.mu.Unlock() + return nil } func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { + c.mu.Lock() + defer c.mu.Unlock() + if !c.relayConnState { + return nil, fmt.Errorf("relay connection is not established") + } + hashedID, hashedStringID := messages.HashID(dstPeerID) log.Infof("open connection to peer: %s", hashedStringID) messageBuffer := make(chan Msg, 2) @@ -93,6 +124,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { } func (c *Client) Close() error { + c.ctxCancel() + return c.close() +} + +func (c *Client) close() error { c.mu.Lock() defer c.mu.Unlock() @@ -101,11 +137,20 @@ func (c *Client) Close() error { } c.relayConnState = false + err := c.relayConn.Close() + return err } func (c *Client) handShake() error { + defer func() { + err := c.relayConn.SetReadDeadline(time.Time{}) + if err != nil { + log.Errorf("failed to reset read deadline: %s", err) + } + }() + msg, err := messages.MarshalHelloMsg(c.hashedID) if err != nil { log.Errorf("failed to marshal hello message: %s", err) @@ -145,7 +190,7 @@ func (c *Client) handShake() error { func (c *Client) readLoop() { defer func() { - c.log.Debugf("exit from read loop") + c.log.Tracef("exit from read loop") }() var errExit error var n int diff --git a/relay/client/dialer/ws/client_conn.go b/relay/client/dialer/ws/client_conn.go index 3298bd228..72a3fa9b4 100644 --- a/relay/client/dialer/ws/client_conn.go +++ b/relay/client/dialer/ws/client_conn.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" ) type Conn struct { @@ -51,6 +52,9 @@ func (c *Conn) SetDeadline(t time.Time) error { } func (c *Conn) Close() error { - _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)) + err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)) + if err != nil { + log.Errorf("failed to close conn?: %s", err) + } return c.Conn.Close() } diff --git a/relay/client/manager.go b/relay/client/manager.go index 4d1aeca79..97793b3ea 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -2,42 +2,74 @@ package client import ( "context" + "net" "sync" + "time" + + log "github.com/sirupsen/logrus" ) type Manager struct { ctx context.Context - ctxCancel context.CancelFunc srvAddress string peerID string - wg sync.WaitGroup + reconnectTime time.Duration - clients map[string]*Client - clientsMutex sync.RWMutex + mu sync.Mutex + client *Client } func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { - ctx, cancel := context.WithCancel(ctx) return &Manager{ - ctx: ctx, - ctxCancel: cancel, - srvAddress: serverAddress, - peerID: peerID, - clients: make(map[string]*Client), + ctx: ctx, + srvAddress: serverAddress, + peerID: peerID, + reconnectTime: 5 * time.Second, } } -func (m *Manager) Teardown() { - m.ctxCancel() - m.wg.Wait() -} - -func (m *Manager) newSrvConnection(address string) { - if _, ok := m.clients[address]; ok { +func (m *Manager) Serve() { + ok := m.mu.TryLock() + if !ok { return } - // client := NewClient(address, m.peerID) - //err = client.Connect() + m.client = NewClient(m.ctx, m.srvAddress, m.peerID) + + go func() { + defer m.mu.Unlock() + + // todo this is not thread safe + for { + select { + case <-m.ctx.Done(): + return + default: + m.connect() + } + + select { + case <-m.ctx.Done(): + return + case <-time.After(2 * time.Second): //timeout + } + } + }() +} + +func (m *Manager) OpenConn(peerKey string) (net.Conn, error) { + // todo m.client nil check + return m.client.OpenConn(peerKey) +} + +// connect is blocking +func (m *Manager) connect() { + err := m.client.Connect() + if err != nil { + if m.ctx.Err() != nil { + return + } + log.Errorf("connection error with '%s': %s", m.srvAddress, err) + } } diff --git a/relay/cmd/main.go b/relay/cmd/main.go index b89ab26ef..80e71486a 100644 --- a/relay/cmd/main.go +++ b/relay/cmd/main.go @@ -23,4 +23,6 @@ func main() { log.Errorf("failed to bind server: %s", err) os.Exit(1) } + + select {} } diff --git a/relay/server/listener/udp/udp_conn.go b/relay/server/listener/udp/conn.go similarity index 100% rename from relay/server/listener/udp/udp_conn.go rename to relay/server/listener/udp/conn.go diff --git a/relay/server/listener/udp/listener.go b/relay/server/listener/udp/listener.go index 400b68a88..ebd1c53f1 100644 --- a/relay/server/listener/udp/listener.go +++ b/relay/server/listener/udp/listener.go @@ -61,6 +61,7 @@ func (l *Listener) Close() error { l.lock.Lock() defer l.lock.Unlock() + log.Infof("closing UDP server") if l.listener == nil { return nil } @@ -95,6 +96,7 @@ func (l *Listener) readLoop() { } pConn = NewConn(l.listener, addr) + log.Infof("new connection from: %s", pConn.RemoteAddr()) l.conns[addr.String()] = pConn go l.onAcceptFn(pConn) pConn.onNewMsg(buf[:n]) diff --git a/relay/server/listener/ws/server_conn.go b/relay/server/listener/ws/conn.go similarity index 71% rename from relay/server/listener/ws/server_conn.go rename to relay/server/listener/ws/conn.go index de3a16781..8734293ac 100644 --- a/relay/server/listener/ws/server_conn.go +++ b/relay/server/listener/ws/conn.go @@ -1,7 +1,9 @@ package ws import ( + "errors" "fmt" + "io" "sync" "time" @@ -24,7 +26,7 @@ func NewConn(wsConn *websocket.Conn) *Conn { func (c *Conn) Read(b []byte) (n int, err error) { t, r, err := c.NextReader() if err != nil { - return 0, err + return 0, ioErrHandling(err) } if t != websocket.BinaryMessage { @@ -32,7 +34,11 @@ func (c *Conn) Read(b []byte) (n int, err error) { return 0, fmt.Errorf("unexpected message type") } - return r.Read(b) + n, err = r.Read(b) + if err != nil { + return 0, ioErrHandling(err) + } + return n, err } func (c *Conn) Write(b []byte) (int, error) { @@ -55,3 +61,14 @@ func (c *Conn) SetDeadline(t time.Time) error { } return nil } + +func ioErrHandling(err error) error { + var wErr *websocket.CloseError + if !errors.As(err, &wErr) { + return err + } + if wErr.Code == websocket.CloseNormalClosure { + return io.EOF + } + return err +} diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index cee57e348..d93dfe0c3 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -42,7 +42,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { Addr: l.address, } - log.Debugf("WS server is listening on address: %s", l.address) + log.Infof("WS server is listening on address: %s", l.address) err := l.server.ListenAndServe() if errors.Is(err, http.ErrServerClosed) { return nil @@ -77,6 +77,7 @@ func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) { return } conn := NewConn(wsConn) + log.Infof("new connection from: %s", conn.RemoteAddr()) l.acceptFn(conn) return } diff --git a/relay/server/peer.go b/relay/server/peer.go index d4b98b9b4..7af113079 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -16,7 +16,6 @@ type Peer struct { } func NewPeer(id []byte, conn net.Conn) *Peer { - log.Debugf("new peer: %v", id) stringID := messages.HashIDToString(id) return &Peer{ Log: log.WithField("peer_id", stringID), diff --git a/relay/server/server.go b/relay/server/server.go index 4b05975ec..f71c888e3 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -1,15 +1,18 @@ package server import ( + "errors" "fmt" "io" "net" + "sync" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/udp" + "github.com/netbirdio/netbird/relay/server/listener/ws" ) // Server @@ -19,7 +22,8 @@ import ( type Server struct { store *Store - listener listener.Listener + UDPListener listener.Listener + WSListener listener.Listener } func NewServer() *Server { @@ -29,15 +33,45 @@ func NewServer() *Server { } func (r *Server) Listen(address string) error { - r.listener = udp.NewListener(address) - return r.listener.Listen(r.accept) + wg := sync.WaitGroup{} + wg.Add(2) + + r.WSListener = ws.NewListener(address) + var wslErr error + go func() { + defer wg.Done() + wslErr = r.WSListener.Listen(r.accept) + if wslErr != nil { + log.Errorf("failed to bind ws server: %s", wslErr) + } + }() + + r.UDPListener = udp.NewListener(address) + var udpLErr error + go func() { + defer wg.Done() + udpLErr = r.UDPListener.Listen(r.accept) + if udpLErr != nil { + log.Errorf("failed to bind ws server: %s", udpLErr) + } + }() + + err := errors.Join(wslErr, udpLErr) + return err } func (r *Server) Close() error { - if r.listener == nil { - return nil + var wErr error + if r.WSListener != nil { + wErr = r.WSListener.Close() } - return r.listener.Close() + + var uErr error + if r.UDPListener != nil { + uErr = r.UDPListener.Close() + } + err := errors.Join(wErr, uErr) + return err } func (r *Server) accept(conn net.Conn) { @@ -50,12 +84,12 @@ func (r *Server) accept(conn net.Conn) { } return } - peer.Log.Debugf("peer connected from: %s", conn.RemoteAddr()) + peer.Log.Infof("peer connected from: %s", conn.RemoteAddr()) r.store.AddPeer(peer) defer func() { - peer.Log.Debugf("teardown connection") r.store.DeletePeer(peer) + peer.Log.Infof("peer left") }() for { diff --git a/relay/test/client_test.go b/relay/test/client_test.go index 675962eef..1d3748abb 100644 --- a/relay/test/client_test.go +++ b/relay/test/client_test.go @@ -1,6 +1,7 @@ package test import ( + "context" "net" "os" "testing" @@ -20,6 +21,8 @@ func TestMain(m *testing.M) { } func TestClient(t *testing.T) { + ctx := context.Background() + addr := "localhost:1234" srv := server.NewServer() go func() { @@ -36,21 +39,21 @@ func TestClient(t *testing.T) { } }() - clientAlice := client.NewClient(addr, "alice") + clientAlice := client.NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientAlice.Close() - clientPlaceHolder := client.NewClient(addr, "clientPlaceHolder") + clientPlaceHolder := client.NewClient(ctx, addr, "clientPlaceHolder") err = clientPlaceHolder.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientPlaceHolder.Close() - clientBob := client.NewClient(addr, "bob") + clientBob := client.NewClient(ctx, addr, "bob") err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -87,6 +90,7 @@ func TestClient(t *testing.T) { } func TestRegistration(t *testing.T) { + ctx := context.Background() addr := "localhost:1234" srv := server.NewServer() go func() { @@ -103,7 +107,7 @@ func TestRegistration(t *testing.T) { } }() - clientAlice := client.NewClient(addr, "alice") + clientAlice := client.NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -117,6 +121,7 @@ func TestRegistration(t *testing.T) { } func TestRegistrationTimeout(t *testing.T) { + ctx := context.Background() udpListener, err := net.ListenUDP("udp", &net.UDPAddr{ Port: 1234, IP: net.ParseIP("0.0.0.0"), @@ -135,7 +140,7 @@ func TestRegistrationTimeout(t *testing.T) { } defer tcpListener.Close() - clientAlice := client.NewClient("127.0.0.1:1234", "alice") + clientAlice := client.NewClient(ctx, "127.0.0.1:1234", "alice") err = clientAlice.Connect() if err == nil { t.Errorf("failed to connect to server: %s", err) @@ -149,6 +154,7 @@ func TestRegistrationTimeout(t *testing.T) { } func TestEcho(t *testing.T) { + ctx := context.Background() idAlice := "alice" idBob := "bob" addr := "localhost:1234" @@ -167,7 +173,7 @@ func TestEcho(t *testing.T) { } }() - clientAlice := client.NewClient(addr, idAlice) + clientAlice := client.NewClient(ctx, addr, idAlice) err := clientAlice.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -179,7 +185,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := client.NewClient(addr, idBob) + clientBob := client.NewClient(ctx, addr, idBob) err = clientBob.Connect() if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -229,6 +235,8 @@ func TestEcho(t *testing.T) { } func TestBindToUnavailabePeer(t *testing.T) { + ctx := context.Background() + addr := "localhost:1234" srv := server.NewServer() go func() { @@ -246,7 +254,7 @@ func TestBindToUnavailabePeer(t *testing.T) { } }() - clientAlice := client.NewClient(addr, "alice") + clientAlice := client.NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -266,6 +274,8 @@ func TestBindToUnavailabePeer(t *testing.T) { } func TestBindReconnect(t *testing.T) { + ctx := context.Background() + addr := "localhost:1234" srv := server.NewServer() go func() { @@ -283,7 +293,7 @@ func TestBindReconnect(t *testing.T) { } }() - clientAlice := client.NewClient(addr, "alice") + clientAlice := client.NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -294,7 +304,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to bind channel: %s", err) } - clientBob := client.NewClient(addr, "bob") + clientBob := client.NewClient(ctx, addr, "bob") err = clientBob.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -311,7 +321,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = client.NewClient(addr, "alice") + clientAlice = client.NewClient(ctx, addr, "alice") err = clientAlice.Connect() if err != nil { t.Errorf("failed to connect to server: %s", err) diff --git a/relay/test/manager_test.go b/relay/test/manager_test.go new file mode 100644 index 000000000..09edeff84 --- /dev/null +++ b/relay/test/manager_test.go @@ -0,0 +1,57 @@ +package test + +import ( + "context" + "testing" + "time" + + "github.com/netbirdio/netbird/relay/client" + "github.com/netbirdio/netbird/relay/server" +) + +func TestManager(t *testing.T) { + addr := "localhost:1239" + + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cm := client.NewManager(ctx, addr, "me") + cm.Serve() + + // wait for the relay handshake to complete + time.Sleep(1 * time.Second) + conn, err := cm.OpenConn("remotepeer") + if err != nil { + t.Errorf("failed to open connection: %s", err) + } + + readCtx, readCancel := context.WithCancel(context.Background()) + defer readCancel() + go func() { + _, _ = conn.Read(make([]byte, 1)) + readCancel() + }() + + cancel() + + select { + case <-time.After(2 * time.Second): + t.Errorf("client peer conn did not close automatically") + case <-readCtx.Done(): + // conn exited well + } +}