diff --git a/relay/client/client.go b/relay/client/client.go index a83ea1415..94c000123 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -10,12 +10,12 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/relay/client/dialer/ws" + "github.com/netbirdio/netbird/relay/client/dialer/udp" "github.com/netbirdio/netbird/relay/messages" ) const ( - bufferSize = 65535 // optimise the buffer size + bufferSize = 1500 // optimise the buffer size ) type connContainer struct { @@ -52,7 +52,7 @@ func NewClient(serverAddress, peerID string) *Client { } func (c *Client) Connect() error { - conn, err := ws.Dial(c.serverAddress) + conn, err := udp.Dial(c.serverAddress) if err != nil { return err } @@ -128,6 +128,7 @@ func (c *Client) readLoop() { buf := c.msgPool.Get().([]byte) n, errExit = c.relayConn.Read(buf) if errExit != nil { + log.Debugf("failed to read message from relay server: %s", errExit) break } @@ -155,7 +156,8 @@ func (c *Client) readLoop() { c.msgPool.Put(buf) continue } - c.handleTransport(channelId, buf[:n]) + go c.handleTransport(channelId, buf[:n]) + } } diff --git a/relay/client/conn.go b/relay/client/conn.go index a5635e020..d450c3f30 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -34,13 +34,11 @@ func (c *Conn) Close() error { } func (c *Conn) LocalAddr() net.Addr { - //TODO implement me - panic("implement me") + return c.client.relayConn.LocalAddr() } func (c *Conn) RemoteAddr() net.Addr { - //TODO implement me - panic("implement me") + return c.client.relayConn.RemoteAddr() } func (c *Conn) SetDeadline(t time.Time) error { diff --git a/relay/client/dialer/udp/udp.go b/relay/client/dialer/udp/udp.go new file mode 100644 index 000000000..ff0fa9c83 --- /dev/null +++ b/relay/client/dialer/udp/udp.go @@ -0,0 +1,14 @@ +package udp + +import ( + "net" +) + +func Dial(address string) (net.Conn, error) { + udpAddr, err := net.ResolveUDPAddr("udp", address) + if err != nil { + return nil, err + } + + return net.DialUDP("udp", nil, udpAddr) +} diff --git a/relay/messages/message.go b/relay/messages/message.go index a7d5452f0..9d728a498 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -17,6 +17,21 @@ var ( type MsgType byte +func (m MsgType) String() string { + switch m { + case MsgTypeHello: + return "hello" + case MsgTypeBindNewChannel: + return "bind new channel" + case MsgTypeBindResponse: + return "bind response" + case MsgTypeTransport: + return "transport" + default: + return "unknown" + } +} + func DetermineClientMsgType(msg []byte) (MsgType, error) { // todo: validate magic byte msgType := MsgType(msg[0]) @@ -41,7 +56,7 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) { case MsgTypeTransport: return msgType, nil default: - return 0, fmt.Errorf("invalid msg type: %s", msg) + return 0, fmt.Errorf("invalid msg type, len: %d", len(msg)) } } diff --git a/relay/server/listener/udp/listener.go b/relay/server/listener/udp/listener.go new file mode 100644 index 000000000..df7fa4c64 --- /dev/null +++ b/relay/server/listener/udp/listener.go @@ -0,0 +1,96 @@ +package udp + +import ( + "net" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/server/listener" +) + +type Listener struct { + address string + + onAcceptFn func(conn net.Conn) + + conns map[string]*UDPConn + wg sync.WaitGroup + quit chan struct{} + lock sync.Mutex + listener *net.UDPConn +} + +func NewListener(address string) listener.Listener { + return &Listener{ + address: address, + conns: make(map[string]*UDPConn), + } +} + +func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error { + l.lock.Lock() + + l.onAcceptFn = onAcceptFn + l.quit = make(chan struct{}) + + addr := &net.UDPAddr{ + Port: 1234, + IP: net.ParseIP("0.0.0.0"), + } + li, err := net.ListenUDP("udp", addr) + if err != nil { + log.Errorf("%s", err) + l.lock.Unlock() + return err + } + log.Debugf("udp server is listening on address: %s", l.address) + l.listener = li + l.wg.Add(1) + go l.readLoop() + + l.lock.Unlock() + <-l.quit + return nil +} + +// Close todo: prevent multiple call (do not close two times the channel) +func (l *Listener) Close() error { + l.lock.Lock() + defer l.lock.Unlock() + + close(l.quit) + err := l.listener.Close() + l.wg.Wait() + return err +} + +func (l *Listener) readLoop() { + defer l.wg.Done() + + for { + buf := make([]byte, 1500) + n, addr, err := l.listener.ReadFromUDP(buf) + if err != nil { + select { + case <-l.quit: + return + default: + log.Errorf("failed to accept connection: %s", err) + continue + } + } + + pConn, ok := l.conns[addr.String()] + if ok { + pConn.onNewMsg(buf[:n]) + continue + } + + pConn = NewConn(l.listener, addr) + l.conns[addr.String()] = pConn + go l.onAcceptFn(pConn) + pConn.onNewMsg(buf[:n]) + + } +} diff --git a/relay/server/listener/udp/udp_conn.go b/relay/server/listener/udp/udp_conn.go new file mode 100644 index 000000000..24b6d4640 --- /dev/null +++ b/relay/server/listener/udp/udp_conn.go @@ -0,0 +1,68 @@ +package udp + +import ( + "io" + "net" + "time" +) + +type UDPConn struct { + *net.UDPConn + addr *net.UDPAddr + msgChannel chan []byte +} + +func NewConn(conn *net.UDPConn, addr *net.UDPAddr) *UDPConn { + return &UDPConn{ + UDPConn: conn, + addr: addr, + msgChannel: make(chan []byte), + } +} + +func (u *UDPConn) Read(b []byte) (n int, err error) { + msg, ok := <-u.msgChannel + if !ok { + return 0, io.EOF + } + + n = copy(b, msg) + return n, nil +} + +func (u *UDPConn) Write(b []byte) (n int, err error) { + return u.UDPConn.WriteTo(b, u.addr) +} + +func (u *UDPConn) Close() error { + //TODO implement me + //panic("implement me") + return nil +} + +func (u *UDPConn) LocalAddr() net.Addr { + return u.UDPConn.LocalAddr() +} + +func (u *UDPConn) RemoteAddr() net.Addr { + return u.addr +} + +func (u *UDPConn) SetDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (u *UDPConn) SetReadDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (u *UDPConn) SetWriteDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (u *UDPConn) onNewMsg(b []byte) { + u.msgChannel <- b +} diff --git a/relay/server/listener/ws/server_conn.go b/relay/server/listener/ws/server_conn.go index 5c7be69fd..de3a16781 100644 --- a/relay/server/listener/ws/server_conn.go +++ b/relay/server/listener/ws/server_conn.go @@ -2,6 +2,7 @@ package ws import ( "fmt" + "sync" "time" "github.com/gorilla/websocket" @@ -10,11 +11,13 @@ import ( type Conn struct { *websocket.Conn + + mu sync.Mutex } func NewConn(wsConn *websocket.Conn) *Conn { return &Conn{ - wsConn, + Conn: wsConn, } } @@ -33,7 +36,9 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Write(b []byte) (int, error) { + c.mu.Lock() err := c.WriteMessage(websocket.BinaryMessage, b) + c.mu.Unlock() return len(b), err } diff --git a/relay/server/server.go b/relay/server/server.go index 7c465132f..bcfda16e7 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -9,7 +9,7 @@ import ( "github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/server/listener" - "github.com/netbirdio/netbird/relay/server/listener/ws" + "github.com/netbirdio/netbird/relay/server/listener/udp" ) // Server @@ -30,7 +30,7 @@ func NewServer() *Server { } func (r *Server) Listen(address string) error { - r.listener = ws.NewListener(address) + r.listener = udp.NewListener(address) return r.listener.Listen(r.accept) } @@ -51,7 +51,7 @@ func (r *Server) accept(conn net.Conn) { } return } - peer.Log.Debugf("on new connection: %s", conn.RemoteAddr()) + peer.Log.Debugf("peer connected from: %s", conn.RemoteAddr()) r.store.AddPeer(peer) defer func() { @@ -59,8 +59,8 @@ func (r *Server) accept(conn net.Conn) { r.store.DeletePeer(peer) }() - buf := make([]byte, 65535) // todo: optimize buffer size for { + buf := make([]byte, 1500) // todo: optimize buffer size n, err := conn.Read(buf) if err != nil { if err != io.EOF { @@ -98,18 +98,18 @@ func (r *Server) accept(conn net.Conn) { peer.Log.Errorf("failed to unmarshal transport message: %s", err) continue } + go func() { + foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId) + if err != nil { + peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) + return + } - foreignChannelID, remoteConn, err := peer.ConnByChannelID(channelId) - if err != nil { - peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) - continue - } - - err = transportTo(remoteConn, foreignChannelID, msg) - if err != nil { - peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) - continue - } + err = transportTo(remoteConn, foreignChannelID, msg) + if err != nil { + peer.Log.Errorf("failed to transport message from peer '%s' to '%d': %s", peer.ID(), channelId, err) + } + }() } } } diff --git a/relay/test/client_test.go b/relay/test/client_test.go index c6a862ab9..b98d6690e 100644 --- a/relay/test/client_test.go +++ b/relay/test/client_test.go @@ -173,7 +173,7 @@ func TestBindToUnavailabePeer(t *testing.T) { go func() { err := srv.Listen(addr) if err != nil { - t.Errorf("failed to bind server: %s", err) + t.Fatalf("failed to bind server: %s", err) } }()