Use buf pool

- eliminate reader function generation
- fix write to closed channel panic
This commit is contained in:
Zoltan Papp 2024-06-09 20:27:40 +02:00
parent 8c70b7d7ff
commit 5e93d117cf
3 changed files with 78 additions and 42 deletions

View File

@ -3,7 +3,6 @@ package client
import ( import (
"context" "context"
"fmt" "fmt"
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
"io" "io"
"net" "net"
"sync" "sync"
@ -11,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
) )
@ -19,13 +19,49 @@ const (
serverResponseTimeout = 8 * time.Second serverResponseTimeout = 8 * time.Second
) )
// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer.
type Msg struct { type Msg struct {
buf []byte Payload []byte
bufPool *sync.Pool
bufPtr *[]byte
}
func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
} }
type connContainer struct { type connContainer struct {
conn *Conn conn *Conn
messages chan Msg messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
return &connContainer{
conn: conn,
messages: messages,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.messages <- msg
}
func (cc *connContainer) close() {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
} }
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and // Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@ -39,6 +75,8 @@ type Client struct {
serverAddress string serverAddress string
hashedID []byte hashedID []byte
bufPool *sync.Pool
relayConn net.Conn relayConn net.Conn
conns map[string]*connContainer conns map[string]*connContainer
serviceIsRunning bool serviceIsRunning bool
@ -61,7 +99,13 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
ctxCancel: func() {}, ctxCancel: func() {},
serverAddress: serverAddress, serverAddress: serverAddress,
hashedID: hashedID, hashedID: hashedID,
conns: make(map[string]*connContainer), bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
conns: make(map[string]*connContainer),
} }
} }
@ -109,13 +153,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
hashedID, hashedStringID := messages.HashID(dstPeerID) hashedID, hashedStringID := messages.HashID(dstPeerID)
log.Infof("open connection to peer: %s", hashedStringID) log.Infof("open connection to peer: %s", hashedStringID)
messageBuffer := make(chan Msg, 2) msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer)) conn := NewConn(c, hashedID, hashedStringID, msgChannel)
c.conns[hashedStringID] = &connContainer{ c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
conn,
messageBuffer,
}
return conn, nil return conn, nil
} }
@ -246,7 +287,8 @@ func (c *Client) readLoop(relayConn net.Conn) {
closedByServer bool closedByServer bool
) )
for { for {
buf := make([]byte, bufferSize) bufPtr := c.bufPool.Get().(*[]byte)
buf := *bufPtr
n, errExit = relayConn.Read(buf) n, errExit = relayConn.Read(buf)
if errExit != nil { if errExit != nil {
c.mu.Lock() c.mu.Lock()
@ -265,7 +307,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
switch msgType { switch msgType {
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
peerID, err := messages.UnmarshalTransportID(buf[:n]) peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
if err != nil { if err != nil {
c.log.Errorf("failed to parse transport message: %v", err) c.log.Errorf("failed to parse transport message: %v", err)
continue continue
@ -284,8 +326,10 @@ func (c *Client) readLoop(relayConn net.Conn) {
continue continue
} }
// todo review is this can cause panic container.writeMsg(Msg{
container.messages <- Msg{buf[:n]} bufPool: c.bufPool,
bufPtr: bufPtr,
Payload: payload})
case messages.MsgClose: case messages.MsgClose:
closedByServer = true closedByServer = true
log.Debugf("relay connection close by server") log.Debugf("relay connection close by server")
@ -321,26 +365,9 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
return n, err return n, err
} }
func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) {
return func(b []byte) (n int, err error) {
msg, ok := <-msgChannel
if !ok {
return 0, io.EOF
}
payload, err := messages.UnmarshalTransportPayload(msg.buf)
if err != nil {
return 0, err
}
n = copy(b, payload)
return n, nil
}
}
func (c *Client) closeAllConns() { func (c *Client) closeAllConns() {
for _, container := range c.conns { for _, container := range c.conns {
close(container.messages) container.close()
} }
c.conns = make(map[string]*connContainer) c.conns = make(map[string]*connContainer)
} }
@ -350,11 +377,11 @@ func (c *Client) closeConn(id string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
conn, ok := c.conns[id] container, ok := c.conns[id]
if !ok { if !ok {
return fmt.Errorf("connection already closed") return fmt.Errorf("connection already closed")
} }
close(conn.messages) container.close()
delete(c.conns, id) delete(c.conns, id)
return nil return nil

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"io"
"net" "net"
"time" "time"
) )
@ -9,15 +10,15 @@ type Conn struct {
client *Client client *Client
dstID []byte dstID []byte
dstStringID string dstStringID string
readerFn func(b []byte) (n int, err error) messageChan chan Msg
} }
func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn { func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
dstID: dstID, dstID: dstID,
dstStringID: dstStringID, dstStringID: dstStringID,
readerFn: readerFn, messageChan: messageChan,
} }
return c return c
@ -28,7 +29,14 @@ func (c *Conn) Write(p []byte) (n int, err error) {
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
return c.readerFn(b) msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
}
n = copy(b, msg.Payload)
msg.Free()
return n, nil
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {

View File

@ -110,12 +110,13 @@ func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
return msg return msg
} }
func UnmarshalTransportPayload(buf []byte) ([]byte, error) { func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
headerSize := 1 + IDSize headerSize := 1 + IDSize
if len(buf) < headerSize { if len(buf) < headerSize {
return nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
return buf[headerSize:], nil
return buf[1:headerSize], buf[headerSize:], nil
} }
func UnmarshalTransportID(buf []byte) ([]byte, error) { func UnmarshalTransportID(buf []byte) ([]byte, error) {