mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Add initial relay code
This commit is contained in:
parent
50201d63c2
commit
57a89cf0cc
1
go.mod
1
go.mod
@ -49,6 +49,7 @@ require (
|
|||||||
github.com/google/martian/v3 v3.0.0
|
github.com/google/martian/v3 v3.0.0
|
||||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||||
github.com/gopacket/gopacket v1.1.1
|
github.com/gopacket/gopacket v1.1.1
|
||||||
|
github.com/gorilla/websocket v1.5.1
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
|
2
go.sum
2
go.sum
@ -317,6 +317,8 @@ github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0 h1:fWY+zXdWhvWnd
|
|||||||
github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||||
|
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
|
||||||
|
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo=
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo=
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4=
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4=
|
||||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||||
|
224
relay/client/client.go
Normal file
224
relay/client/client.go
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bufferSize = 65535 // optimise the buffer size
|
||||||
|
)
|
||||||
|
|
||||||
|
type connContainer struct {
|
||||||
|
conn *Conn
|
||||||
|
messages chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client Todo:
|
||||||
|
// - handle automatic reconnection
|
||||||
|
type Client struct {
|
||||||
|
serverAddress string
|
||||||
|
peerID string
|
||||||
|
|
||||||
|
channelsPending map[string]chan net.Conn // todo: protect map with mutex
|
||||||
|
channels map[uint16]*connContainer
|
||||||
|
msgPool sync.Pool
|
||||||
|
|
||||||
|
relayConn net.Conn
|
||||||
|
relayConnState bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(serverAddress, peerID string) *Client {
|
||||||
|
return &Client{
|
||||||
|
serverAddress: serverAddress,
|
||||||
|
peerID: peerID,
|
||||||
|
channelsPending: make(map[string]chan net.Conn),
|
||||||
|
channels: make(map[uint16]*connContainer),
|
||||||
|
msgPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return make([]byte, bufferSize)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Connect() error {
|
||||||
|
conn, err := ws.Dial(c.serverAddress)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.relayConn = conn
|
||||||
|
|
||||||
|
err = c.handShake()
|
||||||
|
if err != nil {
|
||||||
|
cErr := conn.Close()
|
||||||
|
if cErr != nil {
|
||||||
|
log.Errorf("failed to close connection: %s", cErr)
|
||||||
|
}
|
||||||
|
c.relayConn = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.relayConnState = true
|
||||||
|
go c.readLoop()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) {
|
||||||
|
if c.relayConn == nil {
|
||||||
|
return nil, fmt.Errorf("client not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
bindSuccessChan := make(chan net.Conn, 1)
|
||||||
|
c.channelsPending[remotePeerID] = bindSuccessChan
|
||||||
|
msg := messages.MarshalBindNewChannelMsg(remotePeerID)
|
||||||
|
_, err := c.relayConn.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to write out bind message: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, fmt.Errorf("bind timeout")
|
||||||
|
case c := <-bindSuccessChan:
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
for _, conn := range c.channels {
|
||||||
|
close(conn.messages)
|
||||||
|
}
|
||||||
|
c.channels = make(map[uint16]*connContainer)
|
||||||
|
c.relayConnState = false
|
||||||
|
err := c.relayConn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handShake() error {
|
||||||
|
msg, err := messages.MarshalHelloMsg(c.peerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = c.relayConn.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to send hello message: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) readLoop() {
|
||||||
|
log := log.WithField("client_id", c.peerID)
|
||||||
|
var errExit error
|
||||||
|
var n int
|
||||||
|
for {
|
||||||
|
buf := c.msgPool.Get().([]byte)
|
||||||
|
n, errExit = c.relayConn.Read(buf)
|
||||||
|
if errExit != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to determine message type: %s", err)
|
||||||
|
c.msgPool.Put(buf)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msgType {
|
||||||
|
case messages.MsgTypeBindResponse:
|
||||||
|
channelId, peerId, err := messages.UnmarshalBindResponseMsg(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse bind response message: %v", err)
|
||||||
|
} else {
|
||||||
|
c.handleBindResponse(channelId, peerId)
|
||||||
|
}
|
||||||
|
c.msgPool.Put(buf)
|
||||||
|
continue
|
||||||
|
case messages.MsgTypeTransport:
|
||||||
|
channelId, payload, err := messages.UnmarshalTransportMsg(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse transport message: %v", err)
|
||||||
|
c.msgPool.Put(buf)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.handleTransport(channelId, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.relayConnState {
|
||||||
|
log.Errorf("failed to read message from relay server: %s", errExit)
|
||||||
|
_ = c.relayConn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleBindResponse(channelId uint16, peerId string) {
|
||||||
|
bindSuccessChan, ok := c.channelsPending[peerId]
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("unexpected bind response from: %s", peerId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(c.channelsPending, peerId)
|
||||||
|
|
||||||
|
messageBuffer := make(chan []byte, 10)
|
||||||
|
conn := NewConn(c, channelId, c.generateConnReaderFN(messageBuffer))
|
||||||
|
|
||||||
|
c.channels[channelId] = &connContainer{
|
||||||
|
conn,
|
||||||
|
messageBuffer,
|
||||||
|
}
|
||||||
|
log.Debugf("bind success for '%s': %d", peerId, channelId)
|
||||||
|
|
||||||
|
bindSuccessChan <- conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleTransport(channelId uint16, payload []byte) {
|
||||||
|
container, ok := c.channels[channelId]
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("c.channels: %v", c.peerID)
|
||||||
|
log.Errorf("unexpected transport message for channel: %d", channelId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case container.messages <- payload:
|
||||||
|
default:
|
||||||
|
log.Errorf("dropping message for channel: %d", channelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) writeTo(channelID uint16, payload []byte) (int, error) {
|
||||||
|
msg := messages.MarshalTransportMsg(channelID, payload)
|
||||||
|
n, err := c.relayConn.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to write transport message: %s", err)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) generateConnReaderFN(messageBufferChan chan []byte) func(b []byte) (n int, err error) {
|
||||||
|
return func(b []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case msg, ok := <-messageBufferChan:
|
||||||
|
if !ok {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n = copy(b, msg)
|
||||||
|
c.msgPool.Put(msg)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
}
|
59
relay/client/conn.go
Normal file
59
relay/client/conn.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
client *Client
|
||||||
|
channelID uint16
|
||||||
|
readerFn func(b []byte) (n int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(client *Client, channelID uint16, readerFn func(b []byte) (n int, err error)) *Conn {
|
||||||
|
c := &Conn{
|
||||||
|
client: client,
|
||||||
|
channelID: channelID,
|
||||||
|
readerFn: readerFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||||
|
return c.client.writeTo(c.channelID, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
|
return c.readerFn(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
//TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
7
relay/client/dialer/tcp/tcp.go
Normal file
7
relay/client/dialer/tcp/tcp.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
func Dial(address string) (net.Conn, error) {
|
||||||
|
return net.Dial("tcp", address)
|
||||||
|
}
|
56
relay/client/dialer/ws/client_conn.go
Normal file
56
relay/client/dialer/ws/client_conn.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
*websocket.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(wsConn *websocket.Conn) net.Conn {
|
||||||
|
return &Conn{
|
||||||
|
wsConn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
|
t, r, err := c.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if t != websocket.BinaryMessage {
|
||||||
|
return 0, fmt.Errorf("unexpected message type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Write(b []byte) (int, error) {
|
||||||
|
err := c.WriteMessage(websocket.BinaryMessage, b)
|
||||||
|
return len(b), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
errR := c.SetReadDeadline(t)
|
||||||
|
errW := c.SetWriteDeadline(t)
|
||||||
|
|
||||||
|
if errR != nil {
|
||||||
|
return errR
|
||||||
|
}
|
||||||
|
|
||||||
|
if errW != nil {
|
||||||
|
return errW
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
_ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5))
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
18
relay/client/dialer/ws/ws.go
Normal file
18
relay/client/dialer/ws/ws.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Dial(address string) (net.Conn, error) {
|
||||||
|
addr := fmt.Sprintf("ws://" + address)
|
||||||
|
wsConn, _, err := websocket.DefaultDialer.Dial(addr, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn := NewConn(wsConn)
|
||||||
|
return conn, nil
|
||||||
|
}
|
134
relay/messages/message.go
Normal file
134
relay/messages/message.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MsgTypeHello MsgType = 0
|
||||||
|
MsgTypeBindNewChannel MsgType = 1
|
||||||
|
MsgTypeBindResponse MsgType = 2
|
||||||
|
MsgTypeTransport MsgType = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
|
||||||
|
)
|
||||||
|
|
||||||
|
type MsgType byte
|
||||||
|
|
||||||
|
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||||
|
// todo: validate magic byte
|
||||||
|
msgType := MsgType(msg[0])
|
||||||
|
switch msgType {
|
||||||
|
case MsgTypeHello:
|
||||||
|
return msgType, nil
|
||||||
|
case MsgTypeBindNewChannel:
|
||||||
|
return msgType, nil
|
||||||
|
case MsgTypeTransport:
|
||||||
|
return msgType, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("invalid msg type: %s", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||||
|
// todo: validate magic byte
|
||||||
|
msgType := MsgType(msg[0])
|
||||||
|
switch msgType {
|
||||||
|
case MsgTypeBindResponse:
|
||||||
|
return msgType, nil
|
||||||
|
case MsgTypeTransport:
|
||||||
|
return msgType, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("invalid msg type: %s", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalHelloMsg initial hello message
|
||||||
|
func MarshalHelloMsg(peerID string) ([]byte, error) {
|
||||||
|
if len(peerID) == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid peer id")
|
||||||
|
}
|
||||||
|
msg := make([]byte, 1, 1+len(peerID))
|
||||||
|
msg[0] = byte(MsgTypeHello)
|
||||||
|
msg = append(msg, []byte(peerID)...)
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalHelloMsg(msg []byte) (string, error) {
|
||||||
|
if len(msg) < 2 {
|
||||||
|
return "", fmt.Errorf("invalid 'hello' messge")
|
||||||
|
}
|
||||||
|
return string(msg[1:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind new channel
|
||||||
|
|
||||||
|
func MarshalBindNewChannelMsg(destinationPeerId string) []byte {
|
||||||
|
msg := make([]byte, 1, 1+len(destinationPeerId))
|
||||||
|
msg[0] = byte(MsgTypeBindNewChannel)
|
||||||
|
msg = append(msg, []byte(destinationPeerId)...)
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalBindNewChannel(msg []byte) (string, error) {
|
||||||
|
if len(msg) < 2 {
|
||||||
|
return "", fmt.Errorf("invalid 'bind new channel' messge")
|
||||||
|
}
|
||||||
|
return string(msg[1:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind response
|
||||||
|
|
||||||
|
func MarshalBindResponseMsg(channelId uint16, id string) []byte {
|
||||||
|
data := []byte(id)
|
||||||
|
msg := make([]byte, 3, 3+len(data))
|
||||||
|
msg[0] = byte(MsgTypeBindResponse)
|
||||||
|
msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff)
|
||||||
|
msg = append(msg, data...)
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalBindResponseMsg(buf []byte) (uint16, string, error) {
|
||||||
|
if len(buf) < 3 {
|
||||||
|
return 0, "", ErrInvalidMessageLength
|
||||||
|
}
|
||||||
|
channelId := uint16(buf[1])<<8 | uint16(buf[2])
|
||||||
|
peerID := string(buf[3:])
|
||||||
|
return channelId, peerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport message
|
||||||
|
|
||||||
|
func MarshalTransportMsg(channelId uint16, payload []byte) []byte {
|
||||||
|
msg := make([]byte, 3, 3+len(payload))
|
||||||
|
msg[0] = byte(MsgTypeTransport)
|
||||||
|
msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff)
|
||||||
|
msg = append(msg, payload...)
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalTransportMsg(buf []byte) (uint16, []byte, error) {
|
||||||
|
if len(buf) < 3 {
|
||||||
|
return 0, nil, ErrInvalidMessageLength
|
||||||
|
}
|
||||||
|
channelId := uint16(buf[1])<<8 | uint16(buf[2])
|
||||||
|
return channelId, buf[3:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalTransportID(buf []byte) (uint16, error) {
|
||||||
|
if len(buf) < 3 {
|
||||||
|
return 0, ErrInvalidMessageLength
|
||||||
|
}
|
||||||
|
channelId := uint16(buf[1])<<8 | uint16(buf[2])
|
||||||
|
return channelId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateTransportMsg(msg []byte, channelId uint16) error {
|
||||||
|
if len(msg) < 3 {
|
||||||
|
return ErrInvalidMessageLength
|
||||||
|
}
|
||||||
|
msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff)
|
||||||
|
return nil
|
||||||
|
}
|
8
relay/server/listener/listener.go
Normal file
8
relay/server/listener/listener.go
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
package listener
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
type Listener interface {
|
||||||
|
Listen(func(conn net.Conn)) error
|
||||||
|
Close() error
|
||||||
|
}
|
80
relay/server/listener/tcp/listener.go
Normal file
80
relay/server/listener/tcp/listener.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Listener
|
||||||
|
// Is it just demo code. It does not work in real life environment because the TCP is a streaming protocol, adn
|
||||||
|
// it does not handle framing.
|
||||||
|
type Listener struct {
|
||||||
|
address string
|
||||||
|
|
||||||
|
onAcceptFn func(conn net.Conn)
|
||||||
|
wg sync.WaitGroup
|
||||||
|
quit chan struct{}
|
||||||
|
listener net.Listener
|
||||||
|
lock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListener(address string) listener.Listener {
|
||||||
|
return &Listener{
|
||||||
|
address: address,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
|
||||||
|
l.lock.Lock()
|
||||||
|
|
||||||
|
l.onAcceptFn = onAcceptFn
|
||||||
|
l.quit = make(chan struct{})
|
||||||
|
|
||||||
|
li, err := net.Listen("tcp", l.address)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to listen on address: %s, %s", l.address, err)
|
||||||
|
l.lock.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf("TCP server is listening on address: %s", l.address)
|
||||||
|
l.listener = li
|
||||||
|
l.wg.Add(1)
|
||||||
|
go l.acceptLoop()
|
||||||
|
|
||||||
|
l.lock.Unlock()
|
||||||
|
<-l.quit
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close todo: prevent multiple call (do not close two times the channel)
|
||||||
|
func (l *Listener) Close() error {
|
||||||
|
l.lock.Lock()
|
||||||
|
defer l.lock.Unlock()
|
||||||
|
|
||||||
|
close(l.quit)
|
||||||
|
err := l.listener.Close()
|
||||||
|
l.wg.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) acceptLoop() {
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := l.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-l.quit:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
log.Errorf("failed to accept connection: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go l.onAcceptFn(conn)
|
||||||
|
}
|
||||||
|
}
|
82
relay/server/listener/ws/listener.go
Normal file
82
relay/server/listener/ws/listener.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
upgrader = websocket.Upgrader{} // use default options
|
||||||
|
)
|
||||||
|
|
||||||
|
type Listener struct {
|
||||||
|
address string
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
server *http.Server
|
||||||
|
acceptFn func(conn net.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListener(address string) listener.Listener {
|
||||||
|
return &Listener{
|
||||||
|
address: address,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen todo: prevent multiple call
|
||||||
|
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
|
||||||
|
l.acceptFn = acceptFn
|
||||||
|
http.HandleFunc("/", l.onAccept)
|
||||||
|
|
||||||
|
l.server = &http.Server{
|
||||||
|
Addr: l.address,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("WS server is listening on address: %s", l.address)
|
||||||
|
err := l.server.ListenAndServe()
|
||||||
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) Close() error {
|
||||||
|
if l.server == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
log.Debugf("closing WS server")
|
||||||
|
if err := l.server.Shutdown(ctx); err != nil {
|
||||||
|
return fmt.Errorf("server shutdown failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.wg.Wait()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
l.wg.Add(1)
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
wsConn, err := upgrader.Upgrade(writer, request, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to upgrade connection: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn := NewConn(wsConn)
|
||||||
|
l.acceptFn(conn)
|
||||||
|
return
|
||||||
|
}
|
52
relay/server/listener/ws/server_conn.go
Normal file
52
relay/server/listener/ws/server_conn.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
*websocket.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(wsConn *websocket.Conn) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
wsConn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
|
t, r, err := c.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if t != websocket.BinaryMessage {
|
||||||
|
log.Errorf("unexpected message type: %d", t)
|
||||||
|
return 0, fmt.Errorf("unexpected message type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Write(b []byte) (int, error) {
|
||||||
|
err := c.WriteMessage(websocket.BinaryMessage, b)
|
||||||
|
return len(b), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
errR := c.SetReadDeadline(t)
|
||||||
|
errW := c.SetWriteDeadline(t)
|
||||||
|
|
||||||
|
if errR != nil {
|
||||||
|
return errR
|
||||||
|
}
|
||||||
|
|
||||||
|
if errW != nil {
|
||||||
|
return errW
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
113
relay/server/peer.go
Normal file
113
relay/server/peer.go
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Participant struct {
|
||||||
|
ChannelID uint16
|
||||||
|
ChannelIDForeign uint16
|
||||||
|
ConnForeign net.Conn
|
||||||
|
Peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
Log *log.Entry
|
||||||
|
id string
|
||||||
|
conn net.Conn
|
||||||
|
|
||||||
|
pendingParticipantByChannelID map[uint16]*Participant
|
||||||
|
participantByID map[uint16]*Participant // used for package transfer
|
||||||
|
participantByPeerID map[string]*Participant // used for channel linking
|
||||||
|
|
||||||
|
lastId uint16
|
||||||
|
lastIdLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeer(id string, conn net.Conn) *Peer {
|
||||||
|
return &Peer{
|
||||||
|
Log: log.WithField("peer_id", id),
|
||||||
|
id: id,
|
||||||
|
conn: conn,
|
||||||
|
pendingParticipantByChannelID: make(map[uint16]*Participant),
|
||||||
|
participantByID: make(map[uint16]*Participant),
|
||||||
|
participantByPeerID: make(map[string]*Participant),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (p *Peer) BindChannel(remotePeerId string) uint16 {
|
||||||
|
ch, ok := p.participantByPeerID[remotePeerId]
|
||||||
|
if ok {
|
||||||
|
return ch.ChannelID
|
||||||
|
}
|
||||||
|
|
||||||
|
channelID := p.newChannelID()
|
||||||
|
channel := &Participant{
|
||||||
|
ChannelID: channelID,
|
||||||
|
}
|
||||||
|
p.pendingParticipantByChannelID[channelID] = channel
|
||||||
|
p.participantByPeerID[remotePeerId] = channel
|
||||||
|
return channelID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) UnBindChannel(remotePeerId string) {
|
||||||
|
pa, ok := p.participantByPeerID[remotePeerId]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Log.Debugf("unbind channel with '%s': %d", remotePeerId, pa.ChannelID)
|
||||||
|
p.pendingParticipantByChannelID[pa.ChannelID] = pa
|
||||||
|
delete(p.participantByID, pa.ChannelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AddParticipant(peer *Peer, remoteChannelID uint16) (uint16, bool) {
|
||||||
|
participant, ok := p.participantByPeerID[peer.ID()]
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
participant.ChannelIDForeign = remoteChannelID
|
||||||
|
participant.ConnForeign = peer.conn
|
||||||
|
participant.Peer = peer
|
||||||
|
|
||||||
|
delete(p.pendingParticipantByChannelID, participant.ChannelID)
|
||||||
|
p.participantByID[participant.ChannelID] = participant
|
||||||
|
return participant.ChannelID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) DeleteParticipants() {
|
||||||
|
for _, participant := range p.participantByID {
|
||||||
|
participant.Peer.UnBindChannel(p.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) ConnByChannelID(dstID uint16) (uint16, net.Conn, error) {
|
||||||
|
ch, ok := p.participantByID[dstID]
|
||||||
|
if !ok {
|
||||||
|
return 0, nil, fmt.Errorf("destination channel not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ch.ChannelIDForeign, ch.ConnForeign, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) ID() string {
|
||||||
|
return p.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) newChannelID() uint16 {
|
||||||
|
p.lastIdLock.Lock()
|
||||||
|
defer p.lastIdLock.Unlock()
|
||||||
|
for {
|
||||||
|
p.lastId++
|
||||||
|
if _, ok := p.pendingParticipantByChannelID[p.lastId]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := p.participantByID[p.lastId]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return p.lastId
|
||||||
|
}
|
||||||
|
}
|
149
relay/server/server.go
Normal file
149
relay/server/server.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server
|
||||||
|
// todo:
|
||||||
|
// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents.
|
||||||
|
// connection timeout handling
|
||||||
|
// implement HA (High Availability) mode
|
||||||
|
type Server struct {
|
||||||
|
store *Store
|
||||||
|
|
||||||
|
listener listener.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer() *Server {
|
||||||
|
return &Server{
|
||||||
|
store: NewStore(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Server) Listen(address string) error {
|
||||||
|
r.listener = ws.NewListener(address)
|
||||||
|
return r.listener.Listen(r.accept)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Server) Close() error {
|
||||||
|
if r.listener == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Server) accept(conn net.Conn) {
|
||||||
|
peer, err := handShake(conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to handshake wiht %s: %s", conn.RemoteAddr(), err)
|
||||||
|
cErr := conn.Close()
|
||||||
|
if cErr != nil {
|
||||||
|
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peer.Log.Debugf("on new connection: %s", conn.RemoteAddr())
|
||||||
|
|
||||||
|
r.store.AddPeer(peer)
|
||||||
|
defer func() {
|
||||||
|
peer.Log.Debugf("teardown connection")
|
||||||
|
r.store.DeletePeer(peer)
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := make([]byte, 65535) // todo: optimize buffer size
|
||||||
|
for {
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
peer.Log.Errorf("failed to read message: %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to determine message type: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msgType {
|
||||||
|
case messages.MsgTypeBindNewChannel:
|
||||||
|
dstPeerId, err := messages.UnmarshalBindNewChannel(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to unmarshal bind new channel message: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
channelID := r.store.Link(peer, dstPeerId)
|
||||||
|
|
||||||
|
msg := messages.MarshalBindResponseMsg(channelID, dstPeerId)
|
||||||
|
_, err = conn.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
peer.Log.Errorf("failed to response to bind request: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peer.Log.Debugf("bind new channel with '%s', channelID: %d", dstPeerId, channelID)
|
||||||
|
case messages.MsgTypeTransport:
|
||||||
|
msg := buf[:n]
|
||||||
|
channelId, err := messages.UnmarshalTransportID(msg)
|
||||||
|
if err != nil {
|
||||||
|
peer.Log.Errorf("failed to unmarshal transport message: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId)
|
||||||
|
if err != nil {
|
||||||
|
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transportTo(remoteConn, foreignChannelID, msg)
|
||||||
|
if err != nil {
|
||||||
|
peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func transportTo(conn net.Conn, channelID uint16, msg []byte) error {
|
||||||
|
err := messages.UpdateTransportMsg(msg, channelID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = conn.Write(msg)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func handShake(conn net.Conn) (*Peer, error) {
|
||||||
|
buf := make([]byte, 65535) // todo: reduce the buffer size
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to read message: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgType, err := messages.DetermineClientMsgType(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if msgType != messages.MsgTypeHello {
|
||||||
|
tErr := fmt.Errorf("invalid message type")
|
||||||
|
log.Errorf("failed to handshake: %s", tErr)
|
||||||
|
return nil, tErr
|
||||||
|
}
|
||||||
|
peerId, err := messages.UnmarshalHelloMsg(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to handshake: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p := NewPeer(peerId, conn)
|
||||||
|
return p, nil
|
||||||
|
}
|
48
relay/server/store.go
Normal file
48
relay/server/store.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
peers map[string]*Peer // Key is the id (public key or sha-256) of the peer
|
||||||
|
peersLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore() *Store {
|
||||||
|
return &Store{
|
||||||
|
peers: make(map[string]*Peer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AddPeer(peer *Peer) {
|
||||||
|
s.peersLock.Lock()
|
||||||
|
defer s.peersLock.Unlock()
|
||||||
|
s.peers[peer.ID()] = peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Link(peer *Peer, peerForeignID string) uint16 {
|
||||||
|
s.peersLock.Lock()
|
||||||
|
defer s.peersLock.Unlock()
|
||||||
|
|
||||||
|
channelId := peer.BindChannel(peerForeignID)
|
||||||
|
dstPeer, ok := s.peers[peerForeignID]
|
||||||
|
if !ok {
|
||||||
|
return channelId
|
||||||
|
}
|
||||||
|
|
||||||
|
foreignChannelID, ok := dstPeer.AddParticipant(peer, channelId)
|
||||||
|
if !ok {
|
||||||
|
return channelId
|
||||||
|
}
|
||||||
|
peer.AddParticipant(dstPeer, foreignChannelID)
|
||||||
|
return channelId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeletePeer(peer *Peer) {
|
||||||
|
s.peersLock.Lock()
|
||||||
|
defer s.peersLock.Unlock()
|
||||||
|
|
||||||
|
delete(s.peers, peer.ID())
|
||||||
|
peer.DeleteParticipants()
|
||||||
|
}
|
285
relay/test/client_test.go
Normal file
285
relay/test/client_test.go
Normal file
@ -0,0 +1,285 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/client"
|
||||||
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
_ = util.InitLog("trace", "console")
|
||||||
|
code := m.Run()
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient(t *testing.T) {
|
||||||
|
addr := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientAlice := client.NewClient(addr, "alice")
|
||||||
|
err := clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer clientAlice.Close()
|
||||||
|
|
||||||
|
clientPlaceHolder := client.NewClient(addr, "clientPlaceHolder")
|
||||||
|
err = clientPlaceHolder.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer clientPlaceHolder.Close()
|
||||||
|
|
||||||
|
_, err = clientAlice.BindChannel("clientPlaceHolder")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientBob := client.NewClient(addr, "bob")
|
||||||
|
err = clientBob.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer clientBob.Close()
|
||||||
|
|
||||||
|
connAliceToBob, err := clientAlice.BindChannel("bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connBobToAlice, err := clientBob.BindChannel("alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := "hello bob, I am alice"
|
||||||
|
_, err = connAliceToBob.Write([]byte(payload))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write to channel: %s", err)
|
||||||
|
}
|
||||||
|
log.Debugf("alice sent message to bob")
|
||||||
|
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
n, err := connBobToAlice.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read from channel: %s", err)
|
||||||
|
}
|
||||||
|
log.Debugf("on new message from alice to bob")
|
||||||
|
|
||||||
|
if payload != string(buf[:n]) {
|
||||||
|
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEcho(t *testing.T) {
|
||||||
|
addr := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientAlice := client.NewClient(addr, "alice")
|
||||||
|
err := clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close Alice client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientBob := client.NewClient(addr, "bob")
|
||||||
|
err = clientBob.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := clientBob.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close Bob client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
connAliceToBob, err := clientAlice.BindChannel("bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connBobToAlice, err := clientBob.BindChannel("alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := "hello bob, I am alice"
|
||||||
|
_, err = connAliceToBob.Write([]byte(payload))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write to channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
n, err := connBobToAlice.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read from channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = connBobToAlice.Write(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write to channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = connAliceToBob.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read from channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload != string(buf[:n]) {
|
||||||
|
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindToUnavailabePeer(t *testing.T) {
|
||||||
|
addr := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
log.Infof("closing server")
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientAlice := client.NewClient(addr, "alice")
|
||||||
|
err := clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
log.Infof("closing client")
|
||||||
|
err := clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = clientAlice.BindChannel("bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindReconnect(t *testing.T) {
|
||||||
|
addr := "localhost:1234"
|
||||||
|
srv := server.NewServer()
|
||||||
|
go func() {
|
||||||
|
err := srv.Listen(addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
log.Infof("closing server")
|
||||||
|
err := srv.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close server: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientAlice := client.NewClient(addr, "alice")
|
||||||
|
err := clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientAlice.BindChannel("bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientBob := client.NewClient(addr, "bob")
|
||||||
|
err = clientBob.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chBob, err := clientBob.BindChannel("alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("closing client")
|
||||||
|
err = clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAlice = client.NewClient(addr, "alice")
|
||||||
|
err = clientAlice.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chAlice, err := clientAlice.BindChannel("bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testString := "hello alice, I am bob"
|
||||||
|
_, err = chBob.Write([]byte(testString))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to write to channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
n, err := chAlice.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to read from channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if testString != string(buf[:n]) {
|
||||||
|
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("closing client")
|
||||||
|
err = clientAlice.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user