Add registration response message to the communication

This commit is contained in:
Zoltán Papp 2024-05-21 15:51:37 +02:00
parent 1c9c9ae47e
commit 13eb457132
4 changed files with 115 additions and 10 deletions

View File

@ -9,12 +9,13 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/client/dialer/udp"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 1500 // optimise the buffer size
serverResponseTimeout = 8 * time.Second
)
type connContainer struct {
@ -45,7 +46,7 @@ func NewClient(serverAddress, peerID string) *Client {
}
func (c *Client) Connect() error {
conn, err := ws.Dial(c.serverAddress)
conn, err := udp.Dial(c.serverAddress)
if err != nil {
return err
}
@ -80,7 +81,7 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), serverResponseTimeout)
defer cancel()
select {
case <-ctx.Done():
@ -91,6 +92,10 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) {
}
func (c *Client) Close() error {
if !c.relayConnState {
return nil
}
for _, conn := range c.channels {
close(conn.messages)
}
@ -110,6 +115,30 @@ func (c *Client) handShake() error {
log.Errorf("failed to send hello message: %s", err)
return err
}
err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout))
if err != nil {
log.Errorf("failed to set read deadline: %s", err)
return err
}
buf := make([]byte, 1500) // todo: optimise buffer size
n, err := c.relayConn.Read(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
return err
}
msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
return nil
}

View File

@ -6,9 +6,10 @@ import (
const (
MsgTypeHello MsgType = 0
MsgTypeBindNewChannel MsgType = 1
MsgTypeBindResponse MsgType = 2
MsgTypeTransport MsgType = 3
MsgTypeHelloResponse MsgType = 1
MsgTypeBindNewChannel MsgType = 2
MsgTypeBindResponse MsgType = 3
MsgTypeTransport MsgType = 4
)
var (
@ -51,12 +52,14 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHelloResponse:
return msgType, nil
case MsgTypeBindResponse:
return msgType, nil
case MsgTypeTransport:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
}
}
@ -78,6 +81,12 @@ func UnmarshalHelloMsg(msg []byte) (string, error) {
return string(msg[1:]), nil
}
func MarshalHelloResponse() []byte {
msg := make([]byte, 1)
msg[0] = byte(MsgTypeHelloResponse)
return msg
}
// Bind new channel
func MarshalBindNewChannelMsg(destinationPeerId string) []byte {

View File

@ -145,5 +145,8 @@ func handShake(conn net.Conn) (*Peer, error) {
return nil, err
}
p := NewPeer(peerId, conn)
return p, nil
msg := messages.MarshalHelloResponse()
_, err = conn.Write(msg)
return p, err
}

View File

@ -1,12 +1,14 @@
package test
import (
"net"
"os"
"testing"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/relay/server"
)
@ -89,6 +91,68 @@ func TestClient(t *testing.T) {
}
}
func TestRegistration(t *testing.T) {
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := client.NewClient(addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}()
}
func TestRegistrationTimeout(t *testing.T) {
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind UDP server: %s", err)
}
defer udpListener.Close()
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind TCP server: %s", err)
}
defer tcpListener.Close()
clientAlice := client.NewClient("127.0.0.1:1234", "alice")
err = clientAlice.Connect()
if err == nil {
t.Errorf("failed to connect to server: %s", err)
}
defer func() {
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}()
}
func TestEcho(t *testing.T) {
addr := "localhost:1234"
srv := server.NewServer()