diff --git a/relay/client/client.go b/relay/client/client.go index 0a1066145..f508fb8fc 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -3,7 +3,6 @@ package client import ( "context" "fmt" - ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr" "io" "net" "sync" @@ -11,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" + ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr" "github.com/netbirdio/netbird/relay/messages" ) @@ -19,13 +19,49 @@ const ( serverResponseTimeout = 8 * time.Second ) +// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer. type Msg struct { - buf []byte + Payload []byte + + bufPool *sync.Pool + bufPtr *[]byte +} + +func (m *Msg) Free() { + m.bufPool.Put(m.bufPtr) } type connContainer struct { - conn *Conn - messages chan Msg + conn *Conn + messages chan Msg + msgChanLock sync.Mutex + closed bool // flag to check if channel is closed +} + +func newConnContainer(conn *Conn, messages chan Msg) *connContainer { + return &connContainer{ + conn: conn, + messages: messages, + } +} + +func (cc *connContainer) writeMsg(msg Msg) { + cc.msgChanLock.Lock() + defer cc.msgChanLock.Unlock() + if cc.closed { + return + } + cc.messages <- msg +} + +func (cc *connContainer) close() { + cc.msgChanLock.Lock() + defer cc.msgChanLock.Unlock() + if cc.closed { + return + } + close(cc.messages) + cc.closed = true } // Client is a client for the relay server. It is responsible for establishing a connection to the relay server and @@ -39,6 +75,8 @@ type Client struct { serverAddress string hashedID []byte + bufPool *sync.Pool + relayConn net.Conn conns map[string]*connContainer serviceIsRunning bool @@ -61,7 +99,13 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client { ctxCancel: func() {}, serverAddress: serverAddress, hashedID: hashedID, - conns: make(map[string]*connContainer), + bufPool: &sync.Pool{ + New: func() any { + buf := make([]byte, bufferSize) + return &buf + }, + }, + conns: make(map[string]*connContainer), } } @@ -109,13 +153,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { hashedID, hashedStringID := messages.HashID(dstPeerID) log.Infof("open connection to peer: %s", hashedStringID) - messageBuffer := make(chan Msg, 2) - conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer)) + msgChannel := make(chan Msg, 2) + conn := NewConn(c, hashedID, hashedStringID, msgChannel) - c.conns[hashedStringID] = &connContainer{ - conn, - messageBuffer, - } + c.conns[hashedStringID] = newConnContainer(conn, msgChannel) return conn, nil } @@ -246,7 +287,8 @@ func (c *Client) readLoop(relayConn net.Conn) { closedByServer bool ) for { - buf := make([]byte, bufferSize) + bufPtr := c.bufPool.Get().(*[]byte) + buf := *bufPtr n, errExit = relayConn.Read(buf) if errExit != nil { c.mu.Lock() @@ -265,7 +307,7 @@ func (c *Client) readLoop(relayConn net.Conn) { switch msgType { case messages.MsgTypeTransport: - peerID, err := messages.UnmarshalTransportID(buf[:n]) + peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n]) if err != nil { c.log.Errorf("failed to parse transport message: %v", err) continue @@ -284,8 +326,10 @@ func (c *Client) readLoop(relayConn net.Conn) { continue } - // todo review is this can cause panic - container.messages <- Msg{buf[:n]} + container.writeMsg(Msg{ + bufPool: c.bufPool, + bufPtr: bufPtr, + Payload: payload}) case messages.MsgClose: closedByServer = true log.Debugf("relay connection close by server") @@ -321,26 +365,9 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) { return n, err } -func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) { - return func(b []byte) (n int, err error) { - msg, ok := <-msgChannel - if !ok { - return 0, io.EOF - } - - payload, err := messages.UnmarshalTransportPayload(msg.buf) - if err != nil { - return 0, err - } - - n = copy(b, payload) - return n, nil - } -} - func (c *Client) closeAllConns() { for _, container := range c.conns { - close(container.messages) + container.close() } c.conns = make(map[string]*connContainer) } @@ -350,11 +377,11 @@ func (c *Client) closeConn(id string) error { c.mu.Lock() defer c.mu.Unlock() - conn, ok := c.conns[id] + container, ok := c.conns[id] if !ok { return fmt.Errorf("connection already closed") } - close(conn.messages) + container.close() delete(c.conns, id) return nil diff --git a/relay/client/conn.go b/relay/client/conn.go index 647b0fae4..19f3c8df4 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -1,6 +1,7 @@ package client import ( + "io" "net" "time" ) @@ -9,15 +10,15 @@ type Conn struct { client *Client dstID []byte dstStringID string - readerFn func(b []byte) (n int, err error) + messageChan chan Msg } -func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn { +func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn { c := &Conn{ client: client, dstID: dstID, dstStringID: dstStringID, - readerFn: readerFn, + messageChan: messageChan, } return c @@ -28,7 +29,14 @@ func (c *Conn) Write(p []byte) (n int, err error) { } func (c *Conn) Read(b []byte) (n int, err error) { - return c.readerFn(b) + msg, ok := <-c.messageChan + if !ok { + return 0, io.EOF + } + + n = copy(b, msg.Payload) + msg.Free() + return n, nil } func (c *Conn) Close() error { diff --git a/relay/messages/message.go b/relay/messages/message.go index 7f73daa17..b865687ce 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -110,12 +110,13 @@ func MarshalTransportMsg(peerID []byte, payload []byte) []byte { return msg } -func UnmarshalTransportPayload(buf []byte) ([]byte, error) { +func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { headerSize := 1 + IDSize if len(buf) < headerSize { - return nil, ErrInvalidMessageLength + return nil, nil, ErrInvalidMessageLength } - return buf[headerSize:], nil + + return buf[1:headerSize], buf[headerSize:], nil } func UnmarshalTransportID(buf []byte) ([]byte, error) {