mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-19 04:19:48 +01:00
Add registration response message to the communication
This commit is contained in:
parent
1c9c9ae47e
commit
13eb457132
@ -9,12 +9,13 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
"github.com/netbirdio/netbird/relay/client/dialer/udp"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
bufferSize = 1500 // optimise the buffer size
|
bufferSize = 1500 // optimise the buffer size
|
||||||
|
serverResponseTimeout = 8 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type connContainer struct {
|
type connContainer struct {
|
||||||
@ -45,7 +46,7 @@ func NewClient(serverAddress, peerID string) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
conn, err := ws.Dial(c.serverAddress)
|
conn, err := udp.Dial(c.serverAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -80,7 +81,7 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), serverResponseTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@ -91,6 +92,10 @@ func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
|
if !c.relayConnState {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, conn := range c.channels {
|
for _, conn := range c.channels {
|
||||||
close(conn.messages)
|
close(conn.messages)
|
||||||
}
|
}
|
||||||
@ -110,6 +115,30 @@ func (c *Client) handShake() error {
|
|||||||
log.Errorf("failed to send hello message: %s", err)
|
log.Errorf("failed to send hello message: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to set read deadline: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1500) // todo: optimise buffer size
|
||||||
|
n, err := c.relayConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to read hello response: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to determine message type: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != messages.MsgTypeHelloResponse {
|
||||||
|
log.Errorf("unexpected message type: %s", msgType)
|
||||||
|
return fmt.Errorf("unexpected message type")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,9 +6,10 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
MsgTypeHello MsgType = 0
|
MsgTypeHello MsgType = 0
|
||||||
MsgTypeBindNewChannel MsgType = 1
|
MsgTypeHelloResponse MsgType = 1
|
||||||
MsgTypeBindResponse MsgType = 2
|
MsgTypeBindNewChannel MsgType = 2
|
||||||
MsgTypeTransport MsgType = 3
|
MsgTypeBindResponse MsgType = 3
|
||||||
|
MsgTypeTransport MsgType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -51,12 +52,14 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
|||||||
// todo: validate magic byte
|
// todo: validate magic byte
|
||||||
msgType := MsgType(msg[0])
|
msgType := MsgType(msg[0])
|
||||||
switch msgType {
|
switch msgType {
|
||||||
|
case MsgTypeHelloResponse:
|
||||||
|
return msgType, nil
|
||||||
case MsgTypeBindResponse:
|
case MsgTypeBindResponse:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
case MsgTypeTransport:
|
case MsgTypeTransport:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,6 +81,12 @@ func UnmarshalHelloMsg(msg []byte) (string, error) {
|
|||||||
return string(msg[1:]), nil
|
return string(msg[1:]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MarshalHelloResponse() []byte {
|
||||||
|
msg := make([]byte, 1)
|
||||||
|
msg[0] = byte(MsgTypeHelloResponse)
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
// Bind new channel
|
// Bind new channel
|
||||||
|
|
||||||
func MarshalBindNewChannelMsg(destinationPeerId string) []byte {
|
func MarshalBindNewChannelMsg(destinationPeerId string) []byte {
|
||||||
|
@ -145,5 +145,8 @@ func handShake(conn net.Conn) (*Peer, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
p := NewPeer(peerId, conn)
|
p := NewPeer(peerId, conn)
|
||||||
return p, nil
|
|
||||||
|
msg := messages.MarshalHelloResponse()
|
||||||
|
_, err = conn.Write(msg)
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/client"
|
"github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
)
|
)
|
||||||
@ -89,6 +91,68 @@ func TestClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRegistration(t *testing.T) {
|
||||||
|
addr := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientAlice := client.NewClient(addr, "alice")
|
||||||
|
err := clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close conn: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrationTimeout(t *testing.T) {
|
||||||
|
udpListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
|
Port: 1234,
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind UDP server: %s", err)
|
||||||
|
}
|
||||||
|
defer udpListener.Close()
|
||||||
|
|
||||||
|
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||||
|
Port: 1234,
|
||||||
|
IP: net.ParseIP("0.0.0.0"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind TCP server: %s", err)
|
||||||
|
}
|
||||||
|
defer tcpListener.Close()
|
||||||
|
|
||||||
|
clientAlice := client.NewClient("127.0.0.1:1234", "alice")
|
||||||
|
err = clientAlice.Connect()
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close conn: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
func TestEcho(t *testing.T) {
|
func TestEcho(t *testing.T) {
|
||||||
addr := "localhost:1234"
|
addr := "localhost:1234"
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
|
Loading…
Reference in New Issue
Block a user