diff --git a/go.mod b/go.mod index f75ddcb6f..cf9308a20 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 + github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 diff --git a/go.sum b/go.sum index 49314e729..1c5dd13c1 100644 --- a/go.sum +++ b/go.sum @@ -317,6 +317,8 @@ github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0 h1:fWY+zXdWhvWnd github.com/gopherjs/gopherjs v0.0.0-20220410123724-9e86199038b0/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 h1:Fkzd8ktnpOR9h47SXHe2AYPwelXLH2GjGsjlAloiWfo= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= diff --git a/relay/client/client.go b/relay/client/client.go new file mode 100644 index 000000000..110c75ec2 --- /dev/null +++ b/relay/client/client.go @@ -0,0 +1,224 @@ +package client + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/client/dialer/ws" + "github.com/netbirdio/netbird/relay/messages" +) + +const ( + bufferSize = 65535 // optimise the buffer size +) + +type connContainer struct { + conn *Conn + messages chan []byte +} + +// Client Todo: +// - handle automatic reconnection +type Client struct { + serverAddress string + peerID string + + channelsPending map[string]chan net.Conn // todo: protect map with mutex + channels map[uint16]*connContainer + msgPool sync.Pool + + relayConn net.Conn + relayConnState bool +} + +func NewClient(serverAddress, peerID string) *Client { + return &Client{ + serverAddress: serverAddress, + peerID: peerID, + channelsPending: make(map[string]chan net.Conn), + channels: make(map[uint16]*connContainer), + msgPool: sync.Pool{ + New: func() any { + return make([]byte, bufferSize) + }, + }, + } +} + +func (c *Client) Connect() error { + conn, err := ws.Dial(c.serverAddress) + if err != nil { + return err + } + c.relayConn = conn + + err = c.handShake() + if err != nil { + cErr := conn.Close() + if cErr != nil { + log.Errorf("failed to close connection: %s", cErr) + } + c.relayConn = nil + return err + } + + c.relayConnState = true + go c.readLoop() + return nil +} + +func (c *Client) BindChannel(remotePeerID string) (net.Conn, error) { + if c.relayConn == nil { + return nil, fmt.Errorf("client not connected") + } + + bindSuccessChan := make(chan net.Conn, 1) + c.channelsPending[remotePeerID] = bindSuccessChan + msg := messages.MarshalBindNewChannelMsg(remotePeerID) + _, err := c.relayConn.Write(msg) + if err != nil { + log.Errorf("failed to write out bind message: %s", err) + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + select { + case <-ctx.Done(): + return nil, fmt.Errorf("bind timeout") + case c := <-bindSuccessChan: + return c, nil + } +} + +func (c *Client) Close() error { + for _, conn := range c.channels { + close(conn.messages) + } + c.channels = make(map[uint16]*connContainer) + c.relayConnState = false + err := c.relayConn.Close() + return err +} + +func (c *Client) handShake() error { + msg, err := messages.MarshalHelloMsg(c.peerID) + if err != nil { + return err + } + _, err = c.relayConn.Write(msg) + if err != nil { + log.Errorf("failed to send hello message: %s", err) + return err + } + return nil +} + +func (c *Client) readLoop() { + log := log.WithField("client_id", c.peerID) + var errExit error + var n int + for { + buf := c.msgPool.Get().([]byte) + n, errExit = c.relayConn.Read(buf) + if errExit != nil { + break + } + + msgType, err := messages.DetermineServerMsgType(buf[:n]) + if err != nil { + log.Errorf("failed to determine message type: %s", err) + c.msgPool.Put(buf) + continue + } + + switch msgType { + case messages.MsgTypeBindResponse: + channelId, peerId, err := messages.UnmarshalBindResponseMsg(buf[:n]) + if err != nil { + log.Errorf("failed to parse bind response message: %v", err) + } else { + c.handleBindResponse(channelId, peerId) + } + c.msgPool.Put(buf) + continue + case messages.MsgTypeTransport: + channelId, payload, err := messages.UnmarshalTransportMsg(buf[:n]) + if err != nil { + log.Errorf("failed to parse transport message: %v", err) + c.msgPool.Put(buf) + continue + } + c.handleTransport(channelId, payload) + } + } + + if c.relayConnState { + log.Errorf("failed to read message from relay server: %s", errExit) + _ = c.relayConn.Close() + } +} + +func (c *Client) handleBindResponse(channelId uint16, peerId string) { + bindSuccessChan, ok := c.channelsPending[peerId] + if !ok { + log.Errorf("unexpected bind response from: %s", peerId) + return + } + delete(c.channelsPending, peerId) + + messageBuffer := make(chan []byte, 10) + conn := NewConn(c, channelId, c.generateConnReaderFN(messageBuffer)) + + c.channels[channelId] = &connContainer{ + conn, + messageBuffer, + } + log.Debugf("bind success for '%s': %d", peerId, channelId) + + bindSuccessChan <- conn +} + +func (c *Client) handleTransport(channelId uint16, payload []byte) { + container, ok := c.channels[channelId] + if !ok { + log.Errorf("c.channels: %v", c.peerID) + log.Errorf("unexpected transport message for channel: %d", channelId) + return + } + + select { + case container.messages <- payload: + default: + log.Errorf("dropping message for channel: %d", channelId) + } +} + +func (c *Client) writeTo(channelID uint16, payload []byte) (int, error) { + msg := messages.MarshalTransportMsg(channelID, payload) + n, err := c.relayConn.Write(msg) + if err != nil { + log.Errorf("failed to write transport message: %s", err) + } + return n, err +} + +func (c *Client) generateConnReaderFN(messageBufferChan chan []byte) func(b []byte) (n int, err error) { + return func(b []byte) (n int, err error) { + select { + case msg, ok := <-messageBufferChan: + if !ok { + return 0, io.EOF + } + n = copy(b, msg) + c.msgPool.Put(msg) + } + return n, nil + } +} diff --git a/relay/client/conn.go b/relay/client/conn.go new file mode 100644 index 000000000..a5635e020 --- /dev/null +++ b/relay/client/conn.go @@ -0,0 +1,59 @@ +package client + +import ( + "net" + "time" +) + +type Conn struct { + client *Client + channelID uint16 + readerFn func(b []byte) (n int, err error) +} + +func NewConn(client *Client, channelID uint16, readerFn func(b []byte) (n int, err error)) *Conn { + c := &Conn{ + client: client, + channelID: channelID, + readerFn: readerFn, + } + + return c +} + +func (c *Conn) Write(p []byte) (n int, err error) { + return c.client.writeTo(c.channelID, p) +} + +func (c *Conn) Read(b []byte) (n int, err error) { + return c.readerFn(b) +} + +func (c *Conn) Close() error { + return nil +} + +func (c *Conn) LocalAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (c *Conn) RemoteAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (c *Conn) SetDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} diff --git a/relay/client/dialer/tcp/tcp.go b/relay/client/dialer/tcp/tcp.go new file mode 100644 index 000000000..47b0a31d5 --- /dev/null +++ b/relay/client/dialer/tcp/tcp.go @@ -0,0 +1,7 @@ +package tcp + +import "net" + +func Dial(address string) (net.Conn, error) { + return net.Dial("tcp", address) +} diff --git a/relay/client/dialer/ws/client_conn.go b/relay/client/dialer/ws/client_conn.go new file mode 100644 index 000000000..3298bd228 --- /dev/null +++ b/relay/client/dialer/ws/client_conn.go @@ -0,0 +1,56 @@ +package ws + +import ( + "fmt" + "net" + "time" + + "github.com/gorilla/websocket" +) + +type Conn struct { + *websocket.Conn +} + +func NewConn(wsConn *websocket.Conn) net.Conn { + return &Conn{ + wsConn, + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + t, r, err := c.NextReader() + if err != nil { + return 0, err + } + + if t != websocket.BinaryMessage { + return 0, fmt.Errorf("unexpected message type") + } + + return r.Read(b) +} + +func (c *Conn) Write(b []byte) (int, error) { + err := c.WriteMessage(websocket.BinaryMessage, b) + return len(b), err +} + +func (c *Conn) SetDeadline(t time.Time) error { + errR := c.SetReadDeadline(t) + errW := c.SetWriteDeadline(t) + + if errR != nil { + return errR + } + + if errW != nil { + return errW + } + return nil +} + +func (c *Conn) Close() error { + _ = c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)) + return c.Conn.Close() +} diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go new file mode 100644 index 000000000..3ab3d38a8 --- /dev/null +++ b/relay/client/dialer/ws/ws.go @@ -0,0 +1,18 @@ +package ws + +import ( + "fmt" + "net" + + "github.com/gorilla/websocket" +) + +func Dial(address string) (net.Conn, error) { + addr := fmt.Sprintf("ws://" + address) + wsConn, _, err := websocket.DefaultDialer.Dial(addr, nil) + if err != nil { + return nil, err + } + conn := NewConn(wsConn) + return conn, nil +} diff --git a/relay/messages/message.go b/relay/messages/message.go new file mode 100644 index 000000000..ee4947afe --- /dev/null +++ b/relay/messages/message.go @@ -0,0 +1,134 @@ +package messages + +import ( + "fmt" +) + +const ( + MsgTypeHello MsgType = 0 + MsgTypeBindNewChannel MsgType = 1 + MsgTypeBindResponse MsgType = 2 + MsgTypeTransport MsgType = 3 +) + +var ( + ErrInvalidMessageLength = fmt.Errorf("invalid message length") +) + +type MsgType byte + +func DetermineClientMsgType(msg []byte) (MsgType, error) { + // todo: validate magic byte + msgType := MsgType(msg[0]) + switch msgType { + case MsgTypeHello: + return msgType, nil + case MsgTypeBindNewChannel: + return msgType, nil + case MsgTypeTransport: + return msgType, nil + default: + return 0, fmt.Errorf("invalid msg type: %s", msg) + } +} + +func DetermineServerMsgType(msg []byte) (MsgType, error) { + // todo: validate magic byte + msgType := MsgType(msg[0]) + switch msgType { + case MsgTypeBindResponse: + return msgType, nil + case MsgTypeTransport: + return msgType, nil + default: + return 0, fmt.Errorf("invalid msg type: %s", msg) + } +} + +// MarshalHelloMsg initial hello message +func MarshalHelloMsg(peerID string) ([]byte, error) { + if len(peerID) == 0 { + return nil, fmt.Errorf("invalid peer id") + } + msg := make([]byte, 1, 1+len(peerID)) + msg[0] = byte(MsgTypeHello) + msg = append(msg, []byte(peerID)...) + return msg, nil +} + +func UnmarshalHelloMsg(msg []byte) (string, error) { + if len(msg) < 2 { + return "", fmt.Errorf("invalid 'hello' messge") + } + return string(msg[1:]), nil +} + +// Bind new channel + +func MarshalBindNewChannelMsg(destinationPeerId string) []byte { + msg := make([]byte, 1, 1+len(destinationPeerId)) + msg[0] = byte(MsgTypeBindNewChannel) + msg = append(msg, []byte(destinationPeerId)...) + return msg +} + +func UnmarshalBindNewChannel(msg []byte) (string, error) { + if len(msg) < 2 { + return "", fmt.Errorf("invalid 'bind new channel' messge") + } + return string(msg[1:]), nil +} + +// Bind response + +func MarshalBindResponseMsg(channelId uint16, id string) []byte { + data := []byte(id) + msg := make([]byte, 3, 3+len(data)) + msg[0] = byte(MsgTypeBindResponse) + msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) + msg = append(msg, data...) + return msg +} + +func UnmarshalBindResponseMsg(buf []byte) (uint16, string, error) { + if len(buf) < 3 { + return 0, "", ErrInvalidMessageLength + } + channelId := uint16(buf[1])<<8 | uint16(buf[2]) + peerID := string(buf[3:]) + return channelId, peerID, nil +} + +// Transport message + +func MarshalTransportMsg(channelId uint16, payload []byte) []byte { + msg := make([]byte, 3, 3+len(payload)) + msg[0] = byte(MsgTypeTransport) + msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) + msg = append(msg, payload...) + return msg +} + +func UnmarshalTransportMsg(buf []byte) (uint16, []byte, error) { + if len(buf) < 3 { + return 0, nil, ErrInvalidMessageLength + } + channelId := uint16(buf[1])<<8 | uint16(buf[2]) + return channelId, buf[3:], nil +} + +func UnmarshalTransportID(buf []byte) (uint16, error) { + if len(buf) < 3 { + return 0, ErrInvalidMessageLength + } + channelId := uint16(buf[1])<<8 | uint16(buf[2]) + return channelId, nil +} + +func UpdateTransportMsg(msg []byte, channelId uint16) error { + if len(msg) < 3 { + return ErrInvalidMessageLength + } + msg[1], msg[2] = uint8(channelId>>8), uint8(channelId&0xff) + return nil +} diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go new file mode 100644 index 000000000..66e6d357e --- /dev/null +++ b/relay/server/listener/listener.go @@ -0,0 +1,8 @@ +package listener + +import "net" + +type Listener interface { + Listen(func(conn net.Conn)) error + Close() error +} diff --git a/relay/server/listener/tcp/listener.go b/relay/server/listener/tcp/listener.go new file mode 100644 index 000000000..bd96e3b30 --- /dev/null +++ b/relay/server/listener/tcp/listener.go @@ -0,0 +1,80 @@ +package tcp + +import ( + "net" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/server/listener" +) + +// Listener +// Is it just demo code. It does not work in real life environment because the TCP is a streaming protocol, adn +// it does not handle framing. +type Listener struct { + address string + + onAcceptFn func(conn net.Conn) + wg sync.WaitGroup + quit chan struct{} + listener net.Listener + lock sync.Mutex +} + +func NewListener(address string) listener.Listener { + return &Listener{ + address: address, + } +} + +func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error { + l.lock.Lock() + + l.onAcceptFn = onAcceptFn + l.quit = make(chan struct{}) + + li, err := net.Listen("tcp", l.address) + if err != nil { + log.Errorf("failed to listen on address: %s, %s", l.address, err) + l.lock.Unlock() + return err + } + log.Debugf("TCP server is listening on address: %s", l.address) + l.listener = li + l.wg.Add(1) + go l.acceptLoop() + + l.lock.Unlock() + <-l.quit + return nil +} + +// Close todo: prevent multiple call (do not close two times the channel) +func (l *Listener) Close() error { + l.lock.Lock() + defer l.lock.Unlock() + + close(l.quit) + err := l.listener.Close() + l.wg.Wait() + return err +} + +func (l *Listener) acceptLoop() { + defer l.wg.Done() + + for { + conn, err := l.listener.Accept() + if err != nil { + select { + case <-l.quit: + return + default: + log.Errorf("failed to accept connection: %s", err) + continue + } + } + go l.onAcceptFn(conn) + } +} diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go new file mode 100644 index 000000000..cee57e348 --- /dev/null +++ b/relay/server/listener/ws/listener.go @@ -0,0 +1,82 @@ +package ws + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/server/listener" +) + +var ( + upgrader = websocket.Upgrader{} // use default options +) + +type Listener struct { + address string + + wg sync.WaitGroup + server *http.Server + acceptFn func(conn net.Conn) +} + +func NewListener(address string) listener.Listener { + return &Listener{ + address: address, + } +} + +// Listen todo: prevent multiple call +func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { + l.acceptFn = acceptFn + http.HandleFunc("/", l.onAccept) + + l.server = &http.Server{ + Addr: l.address, + } + + log.Debugf("WS server is listening on address: %s", l.address) + err := l.server.ListenAndServe() + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err +} + +func (l *Listener) Close() error { + if l.server == nil { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + log.Debugf("closing WS server") + if err := l.server.Shutdown(ctx); err != nil { + return fmt.Errorf("server shutdown failed: %v", err) + } + + l.wg.Wait() + return nil +} + +func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) { + l.wg.Add(1) + defer l.wg.Done() + + wsConn, err := upgrader.Upgrade(writer, request, nil) + if err != nil { + log.Errorf("failed to upgrade connection: %s", err) + return + } + conn := NewConn(wsConn) + l.acceptFn(conn) + return +} diff --git a/relay/server/listener/ws/server_conn.go b/relay/server/listener/ws/server_conn.go new file mode 100644 index 000000000..5c7be69fd --- /dev/null +++ b/relay/server/listener/ws/server_conn.go @@ -0,0 +1,52 @@ +package ws + +import ( + "fmt" + "time" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" +) + +type Conn struct { + *websocket.Conn +} + +func NewConn(wsConn *websocket.Conn) *Conn { + return &Conn{ + wsConn, + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + t, r, err := c.NextReader() + if err != nil { + return 0, err + } + + if t != websocket.BinaryMessage { + log.Errorf("unexpected message type: %d", t) + return 0, fmt.Errorf("unexpected message type") + } + + return r.Read(b) +} + +func (c *Conn) Write(b []byte) (int, error) { + err := c.WriteMessage(websocket.BinaryMessage, b) + return len(b), err +} + +func (c *Conn) SetDeadline(t time.Time) error { + errR := c.SetReadDeadline(t) + errW := c.SetWriteDeadline(t) + + if errR != nil { + return errR + } + + if errW != nil { + return errW + } + return nil +} diff --git a/relay/server/peer.go b/relay/server/peer.go new file mode 100644 index 000000000..2e40cbb12 --- /dev/null +++ b/relay/server/peer.go @@ -0,0 +1,113 @@ +package server + +import ( + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" +) + +type Participant struct { + ChannelID uint16 + ChannelIDForeign uint16 + ConnForeign net.Conn + Peer *Peer +} + +type Peer struct { + Log *log.Entry + id string + conn net.Conn + + pendingParticipantByChannelID map[uint16]*Participant + participantByID map[uint16]*Participant // used for package transfer + participantByPeerID map[string]*Participant // used for channel linking + + lastId uint16 + lastIdLock sync.Mutex +} + +func NewPeer(id string, conn net.Conn) *Peer { + return &Peer{ + Log: log.WithField("peer_id", id), + id: id, + conn: conn, + pendingParticipantByChannelID: make(map[uint16]*Participant), + participantByID: make(map[uint16]*Participant), + participantByPeerID: make(map[string]*Participant), + } +} +func (p *Peer) BindChannel(remotePeerId string) uint16 { + ch, ok := p.participantByPeerID[remotePeerId] + if ok { + return ch.ChannelID + } + + channelID := p.newChannelID() + channel := &Participant{ + ChannelID: channelID, + } + p.pendingParticipantByChannelID[channelID] = channel + p.participantByPeerID[remotePeerId] = channel + return channelID +} + +func (p *Peer) UnBindChannel(remotePeerId string) { + pa, ok := p.participantByPeerID[remotePeerId] + if !ok { + return + } + + p.Log.Debugf("unbind channel with '%s': %d", remotePeerId, pa.ChannelID) + p.pendingParticipantByChannelID[pa.ChannelID] = pa + delete(p.participantByID, pa.ChannelID) +} + +func (p *Peer) AddParticipant(peer *Peer, remoteChannelID uint16) (uint16, bool) { + participant, ok := p.participantByPeerID[peer.ID()] + if !ok { + return 0, false + } + participant.ChannelIDForeign = remoteChannelID + participant.ConnForeign = peer.conn + participant.Peer = peer + + delete(p.pendingParticipantByChannelID, participant.ChannelID) + p.participantByID[participant.ChannelID] = participant + return participant.ChannelID, true +} + +func (p *Peer) DeleteParticipants() { + for _, participant := range p.participantByID { + participant.Peer.UnBindChannel(p.id) + } +} + +func (p *Peer) ConnByChannelID(dstID uint16) (uint16, net.Conn, error) { + ch, ok := p.participantByID[dstID] + if !ok { + return 0, nil, fmt.Errorf("destination channel not found") + } + + return ch.ChannelIDForeign, ch.ConnForeign, nil +} + +func (p *Peer) ID() string { + return p.id +} + +func (p *Peer) newChannelID() uint16 { + p.lastIdLock.Lock() + defer p.lastIdLock.Unlock() + for { + p.lastId++ + if _, ok := p.pendingParticipantByChannelID[p.lastId]; ok { + continue + } + if _, ok := p.participantByID[p.lastId]; ok { + continue + } + return p.lastId + } +} diff --git a/relay/server/server.go b/relay/server/server.go new file mode 100644 index 000000000..7c465132f --- /dev/null +++ b/relay/server/server.go @@ -0,0 +1,149 @@ +package server + +import ( + "fmt" + "io" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/messages" + "github.com/netbirdio/netbird/relay/server/listener" + "github.com/netbirdio/netbird/relay/server/listener/ws" +) + +// Server +// todo: +// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents. +// connection timeout handling +// implement HA (High Availability) mode +type Server struct { + store *Store + + listener listener.Listener +} + +func NewServer() *Server { + return &Server{ + store: NewStore(), + } +} + +func (r *Server) Listen(address string) error { + r.listener = ws.NewListener(address) + return r.listener.Listen(r.accept) +} + +func (r *Server) Close() error { + if r.listener == nil { + return nil + } + return r.listener.Close() +} + +func (r *Server) accept(conn net.Conn) { + peer, err := handShake(conn) + if err != nil { + log.Errorf("failed to handshake wiht %s: %s", conn.RemoteAddr(), err) + cErr := conn.Close() + if cErr != nil { + log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) + } + return + } + peer.Log.Debugf("on new connection: %s", conn.RemoteAddr()) + + r.store.AddPeer(peer) + defer func() { + peer.Log.Debugf("teardown connection") + r.store.DeletePeer(peer) + }() + + buf := make([]byte, 65535) // todo: optimize buffer size + for { + n, err := conn.Read(buf) + if err != nil { + if err != io.EOF { + peer.Log.Errorf("failed to read message: %s", err) + } + return + } + + msgType, err := messages.DetermineClientMsgType(buf[:n]) + if err != nil { + log.Errorf("failed to determine message type: %s", err) + return + } + switch msgType { + case messages.MsgTypeBindNewChannel: + dstPeerId, err := messages.UnmarshalBindNewChannel(buf[:n]) + if err != nil { + log.Errorf("failed to unmarshal bind new channel message: %s", err) + continue + } + + channelID := r.store.Link(peer, dstPeerId) + + msg := messages.MarshalBindResponseMsg(channelID, dstPeerId) + _, err = conn.Write(msg) + if err != nil { + peer.Log.Errorf("failed to response to bind request: %s", err) + continue + } + peer.Log.Debugf("bind new channel with '%s', channelID: %d", dstPeerId, channelID) + case messages.MsgTypeTransport: + msg := buf[:n] + channelId, err := messages.UnmarshalTransportID(msg) + if err != nil { + peer.Log.Errorf("failed to unmarshal transport message: %s", err) + continue + } + + foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId) + if err != nil { + peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) + continue + } + + err = transportTo(remoteConn, foreignChannelID, msg) + if err != nil { + peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) + continue + } + } + } +} + +func transportTo(conn net.Conn, channelID uint16, msg []byte) error { + err := messages.UpdateTransportMsg(msg, channelID) + if err != nil { + return err + } + _, err = conn.Write(msg) + return err +} + +func handShake(conn net.Conn) (*Peer, error) { + buf := make([]byte, 65535) // todo: reduce the buffer size + n, err := conn.Read(buf) + if err != nil { + log.Errorf("failed to read message: %s", err) + return nil, err + } + msgType, err := messages.DetermineClientMsgType(buf[:n]) + if err != nil { + return nil, err + } + if msgType != messages.MsgTypeHello { + tErr := fmt.Errorf("invalid message type") + log.Errorf("failed to handshake: %s", tErr) + return nil, tErr + } + peerId, err := messages.UnmarshalHelloMsg(buf[:n]) + if err != nil { + log.Errorf("failed to handshake: %s", err) + return nil, err + } + p := NewPeer(peerId, conn) + return p, nil +} diff --git a/relay/server/store.go b/relay/server/store.go new file mode 100644 index 000000000..fdbc118e4 --- /dev/null +++ b/relay/server/store.go @@ -0,0 +1,48 @@ +package server + +import ( + "sync" +) + +type Store struct { + peers map[string]*Peer // Key is the id (public key or sha-256) of the peer + peersLock sync.Mutex +} + +func NewStore() *Store { + return &Store{ + peers: make(map[string]*Peer), + } +} + +func (s *Store) AddPeer(peer *Peer) { + s.peersLock.Lock() + defer s.peersLock.Unlock() + s.peers[peer.ID()] = peer +} + +func (s *Store) Link(peer *Peer, peerForeignID string) uint16 { + s.peersLock.Lock() + defer s.peersLock.Unlock() + + channelId := peer.BindChannel(peerForeignID) + dstPeer, ok := s.peers[peerForeignID] + if !ok { + return channelId + } + + foreignChannelID, ok := dstPeer.AddParticipant(peer, channelId) + if !ok { + return channelId + } + peer.AddParticipant(dstPeer, foreignChannelID) + return channelId +} + +func (s *Store) DeletePeer(peer *Peer) { + s.peersLock.Lock() + defer s.peersLock.Unlock() + + delete(s.peers, peer.ID()) + peer.DeleteParticipants() +} diff --git a/relay/test/client_test.go b/relay/test/client_test.go new file mode 100644 index 000000000..c6a862ab9 --- /dev/null +++ b/relay/test/client_test.go @@ -0,0 +1,285 @@ +package test + +import ( + "os" + "testing" + + "github.com/netbirdio/netbird/util" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/client" + "github.com/netbirdio/netbird/relay/server" +) + +func TestMain(m *testing.M) { + _ = util.InitLog("trace", "console") + code := m.Run() + os.Exit(code) +} + +func TestClient(t *testing.T) { + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := client.NewClient(addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer clientAlice.Close() + + clientPlaceHolder := client.NewClient(addr, "clientPlaceHolder") + err = clientPlaceHolder.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer clientPlaceHolder.Close() + + _, err = clientAlice.BindChannel("clientPlaceHolder") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + clientBob := client.NewClient(addr, "bob") + err = clientBob.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer clientBob.Close() + + connAliceToBob, err := clientAlice.BindChannel("bob") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + connBobToAlice, err := clientBob.BindChannel("alice") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + log.Debugf("alice sent message to bob") + + buf := make([]byte, 65535) + n, err := connBobToAlice.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + log.Debugf("on new message from alice to bob") + + if payload != string(buf[:n]) { + t.Fatalf("expected %s, got %s", payload, string(buf[:n])) + } +} + +func TestEcho(t *testing.T) { + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := client.NewClient(addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientAlice.Close() + if err != nil { + t.Errorf("failed to close Alice client: %s", err) + } + }() + + clientBob := client.NewClient(addr, "bob") + err = clientBob.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientBob.Close() + if err != nil { + t.Errorf("failed to close Bob client: %s", err) + } + }() + + connAliceToBob, err := clientAlice.BindChannel("bob") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + connBobToAlice, err := clientBob.BindChannel("alice") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + buf := make([]byte, 65535) + n, err := connBobToAlice.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + _, err = connBobToAlice.Write(buf[:n]) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + n, err = connAliceToBob.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + if payload != string(buf[:n]) { + t.Fatalf("expected %s, got %s", payload, string(buf[:n])) + } +} + +func TestBindToUnavailabePeer(t *testing.T) { + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Errorf("failed to bind server: %s", err) + } + }() + + defer func() { + log.Infof("closing server") + err := srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := client.NewClient(addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + defer func() { + log.Infof("closing client") + err := clientAlice.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } + }() + + _, err = clientAlice.BindChannel("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } +} + +func TestBindReconnect(t *testing.T) { + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Errorf("failed to bind server: %s", err) + } + }() + + defer func() { + log.Infof("closing server") + err := srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + clientAlice := client.NewClient(addr, "alice") + err := clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + _, err = clientAlice.BindChannel("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + clientBob := client.NewClient(addr, "bob") + err = clientBob.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + chBob, err := clientBob.BindChannel("alice") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + log.Infof("closing client") + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } + + clientAlice = client.NewClient(addr, "alice") + err = clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + chAlice, err := clientAlice.BindChannel("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + testString := "hello alice, I am bob" + _, err = chBob.Write([]byte(testString)) + if err != nil { + t.Errorf("failed to write to channel: %s", err) + } + + buf := make([]byte, 65535) + n, err := chAlice.Read(buf) + if err != nil { + t.Errorf("failed to read from channel: %s", err) + } + + if testString != string(buf[:n]) { + t.Errorf("expected %s, got %s", testString, string(buf[:n])) + } + + log.Infof("closing client") + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } +}