mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-19 04:19:48 +01:00
Add registration response message to the communication
This commit is contained in:
parent
1c9c9ae47e
commit
13eb457132
@ -9,12 +9,13 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/client/dialer/udp"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 1500 // optimise the buffer size
|
||||
serverResponseTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
type connContainer struct {
|
||||
@ -45,7 +46,7 @@ func NewClient(serverAddress, peerID string) *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Connect() error {
|
||||
conn, err := ws.Dial(c.serverAddress)
|
||||
conn, err := udp.Dial(c.serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -80,7 +81,7 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), serverResponseTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -91,6 +92,10 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
if !c.relayConnState {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, conn := range c.channels {
|
||||
close(conn.messages)
|
||||
}
|
||||
@ -110,6 +115,30 @@ func (c *Client) handShake() error {
|
||||
log.Errorf("failed to send hello message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("failed to set read deadline: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 1500) // todo: optimise buffer size
|
||||
n, err := c.relayConn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read hello response: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHelloResponse {
|
||||
log.Errorf("unexpected message type: %s", msgType)
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -6,9 +6,10 @@ import (
|
||||
|
||||
const (
|
||||
MsgTypeHello MsgType = 0
|
||||
MsgTypeBindNewChannel MsgType = 1
|
||||
MsgTypeBindResponse MsgType = 2
|
||||
MsgTypeTransport MsgType = 3
|
||||
MsgTypeHelloResponse MsgType = 1
|
||||
MsgTypeBindNewChannel MsgType = 2
|
||||
MsgTypeBindResponse MsgType = 3
|
||||
MsgTypeTransport MsgType = 4
|
||||
)
|
||||
|
||||
var (
|
||||
@ -51,12 +52,14 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||
// todo: validate magic byte
|
||||
msgType := MsgType(msg[0])
|
||||
switch msgType {
|
||||
case MsgTypeHelloResponse:
|
||||
return msgType, nil
|
||||
case MsgTypeBindResponse:
|
||||
return msgType, nil
|
||||
case MsgTypeTransport:
|
||||
return msgType, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
||||
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,6 +81,12 @@ func UnmarshalHelloMsg(msg []byte) (string, error) {
|
||||
return string(msg[1:]), nil
|
||||
}
|
||||
|
||||
func MarshalHelloResponse() []byte {
|
||||
msg := make([]byte, 1)
|
||||
msg[0] = byte(MsgTypeHelloResponse)
|
||||
return msg
|
||||
}
|
||||
|
||||
// Bind new channel
|
||||
|
||||
func MarshalBindNewChannelMsg(destinationPeerId string) []byte {
|
||||
|
@ -145,5 +145,8 @@ func handShake(conn net.Conn) (*Peer, error) {
|
||||
return nil, err
|
||||
}
|
||||
p := NewPeer(peerId, conn)
|
||||
return p, nil
|
||||
|
||||
msg := messages.MarshalHelloResponse()
|
||||
_, err = conn.Write(msg)
|
||||
return p, err
|
||||
}
|
||||
|
@ -1,12 +1,14 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
@ -89,6 +91,68 @@ func TestClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistration(t *testing.T) {
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := client.NewClient(addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestRegistrationTimeout(t *testing.T) {
|
||||
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind UDP server: %s", err)
|
||||
}
|
||||
defer udpListener.Close()
|
||||
|
||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind TCP server: %s", err)
|
||||
}
|
||||
defer tcpListener.Close()
|
||||
|
||||
clientAlice := client.NewClient("127.0.0.1:1234", "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err == nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestEcho(t *testing.T) {
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
|
Loading…
Reference in New Issue
Block a user