From 13eb457132e9291559b538b1d55429e94e879e18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Tue, 21 May 2024 15:51:37 +0200 Subject: [PATCH] Add registration response message to the communication --- relay/client/client.go | 37 +++++++++++++++++++--- relay/messages/message.go | 17 +++++++--- relay/server/server.go | 5 ++- relay/test/client_test.go | 66 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 115 insertions(+), 10 deletions(-) diff --git a/relay/client/client.go b/relay/client/client.go index 4cb29857d..3e7fe185a 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -9,12 +9,13 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/relay/client/dialer/ws" + "github.com/netbirdio/netbird/relay/client/dialer/udp" "github.com/netbirdio/netbird/relay/messages" ) const ( - bufferSize = 1500 // optimise the buffer size + bufferSize = 1500 // optimise the buffer size + serverResponseTimeout = 8 * time.Second ) type connContainer struct { @@ -45,7 +46,7 @@ func NewClient(serverAddress, peerID string) *Client { } func (c *Client) Connect() error { - conn, err := ws.Dial(c.serverAddress) + conn, err := udp.Dial(c.serverAddress) if err != nil { return err } @@ -80,7 +81,7 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) { return nil, err } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), serverResponseTimeout) defer cancel() select { case <-ctx.Done(): @@ -91,6 +92,10 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) { } func (c *Client) Close() error { + if !c.relayConnState { + return nil + } + for _, conn := range c.channels { close(conn.messages) } @@ -110,6 +115,30 @@ func (c *Client) handShake() error { log.Errorf("failed to send hello message: %s", err) return err } + + err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout)) + if err != nil { + log.Errorf("failed to set read deadline: %s", err) + return err + } + + buf := make([]byte, 1500) // todo: optimise buffer size + n, err := c.relayConn.Read(buf) + if err != nil { + log.Errorf("failed to read hello response: %s", err) + return err + } + + msgType, err := messages.DetermineServerMsgType(buf[:n]) + if err != nil { + log.Errorf("failed to determine message type: %s", err) + return err + } + + if msgType != messages.MsgTypeHelloResponse { + log.Errorf("unexpected message type: %s", msgType) + return fmt.Errorf("unexpected message type") + } return nil } diff --git a/relay/messages/message.go b/relay/messages/message.go index 9d728a498..02945e2f1 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -6,9 +6,10 @@ import ( const ( MsgTypeHello MsgType = 0 - MsgTypeBindNewChannel MsgType = 1 - MsgTypeBindResponse MsgType = 2 - MsgTypeTransport MsgType = 3 + MsgTypeHelloResponse MsgType = 1 + MsgTypeBindNewChannel MsgType = 2 + MsgTypeBindResponse MsgType = 3 + MsgTypeTransport MsgType = 4 ) var ( @@ -51,12 +52,14 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) { // todo: validate magic byte msgType := MsgType(msg[0]) switch msgType { + case MsgTypeHelloResponse: + return msgType, nil case MsgTypeBindResponse: return msgType, nil case MsgTypeTransport: return msgType, nil default: - return 0, fmt.Errorf("invalid msg type, len: %d", len(msg)) + return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg)) } } @@ -78,6 +81,12 @@ func UnmarshalHelloMsg(msg []byte) (string, error) { return string(msg[1:]), nil } +func MarshalHelloResponse() []byte { + msg := make([]byte, 1) + msg[0] = byte(MsgTypeHelloResponse) + return msg +} + // Bind new channel func MarshalBindNewChannelMsg(destinationPeerId string) []byte { diff --git a/relay/server/server.go b/relay/server/server.go index e26ea3f5c..a66341e51 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -145,5 +145,8 @@ func handShake(conn net.Conn) (*Peer, error) { return nil, err } p := NewPeer(peerId, conn) - return p, nil + + msg := messages.MarshalHelloResponse() + _, err = conn.Write(msg) + return p, err } diff --git a/relay/test/client_test.go b/relay/test/client_test.go index b98d6690e..87a07fa94 100644 --- a/relay/test/client_test.go +++ b/relay/test/client_test.go @@ -1,12 +1,14 @@ package test import ( + "net" "os" "testing" - "github.com/netbirdio/netbird/util" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/relay/server" ) @@ -89,6 +91,68 @@ func TestClient(t *testing.T) { } } +func TestRegistration(t *testing.T) { + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := client.NewClient(addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close conn: %s", err) + } + }() +} + +func TestRegistrationTimeout(t *testing.T) { + udpListener, err := net.ListenUDP("udp", &net.UDPAddr{ + Port: 1234, + IP: net.ParseIP("0.0.0.0"), + }) + if err != nil { + t.Fatalf("failed to bind UDP server: %s", err) + } + defer udpListener.Close() + + tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{ + Port: 1234, + IP: net.ParseIP("0.0.0.0"), + }) + if err != nil { + t.Fatalf("failed to bind TCP server: %s", err) + } + defer tcpListener.Close() + + clientAlice := client.NewClient("127.0.0.1:1234", "alice") + err = clientAlice.Connect() + if err == nil { + t.Errorf("failed to connect to server: %s", err) + } + defer func() { + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close conn: %s", err) + } + }() +} + func TestEcho(t *testing.T) { addr := "localhost:1234" srv := server.NewServer()