Use buf pool

- eliminate reader function generation
- fix write to closed channel panic
Zoltan Papp 2024-06-09 20:27:40 +02:00
parent 8c70b7d7ff
commit 5e93d117cf
3 changed files with 78 additions and 42 deletions
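Editor's note: a minimal sketch of the buffer-pool pattern this commit adopts (illustrative names, not the commit's code). Storing a `*[]byte` rather than a `[]byte` in the pool avoids an extra allocation when the slice header is boxed into the `any` that `Get` and `Put` traffic in:

```go
package main

import (
	"fmt"
	"sync"
)

// bufferSize is illustrative; the client package defines its own constant.
const bufferSize = 1500

// Pooling *[]byte instead of []byte keeps Get/Put allocation-free:
// a pointer fits in the pool's interface value without re-boxing
// the slice header on every round trip.
var bufPool = sync.Pool{
	New: func() any {
		buf := make([]byte, bufferSize)
		return &buf
	},
}

func main() {
	bufPtr := bufPool.Get().(*[]byte) // reuse a buffer if one is available
	buf := *bufPtr

	n := copy(buf, "payload")
	fmt.Printf("filled %d bytes\n", n)

	bufPool.Put(bufPtr) // hand the buffer back once the consumer is done
}
```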

View File

@@ -3,7 +3,6 @@ package client
 import (
 	"context"
 	"fmt"
-	ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
 	"io"
 	"net"
 	"sync"
@@ -11,6 +10,7 @@ import (
 	log "github.com/sirupsen/logrus"
 
+	ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
 	"github.com/netbirdio/netbird/relay/messages"
 )
@@ -19,13 +19,49 @@ const (
 	serverResponseTimeout = 8 * time.Second
 )
 
+// Msg carries the payload from the server to the client. With this struct, the net.Conn can free the buffer.
 type Msg struct {
-	buf []byte
+	Payload []byte
+
+	bufPool *sync.Pool
+	bufPtr  *[]byte
 }
+
+func (m *Msg) Free() {
+	m.bufPool.Put(m.bufPtr)
+}
 
 type connContainer struct {
-	conn     *Conn
-	messages chan Msg
+	conn        *Conn
+	messages    chan Msg
+	msgChanLock sync.Mutex
+	closed      bool // flag to check if channel is closed
 }
+
+func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
+	return &connContainer{
+		conn:     conn,
+		messages: messages,
+	}
+}
+
+func (cc *connContainer) writeMsg(msg Msg) {
+	cc.msgChanLock.Lock()
+	defer cc.msgChanLock.Unlock()
+
+	if cc.closed {
+		return
+	}
+	cc.messages <- msg
+}
+
+func (cc *connContainer) close() {
+	cc.msgChanLock.Lock()
+	defer cc.msgChanLock.Unlock()
+
+	if cc.closed {
+		return
+	}
+	close(cc.messages)
+	cc.closed = true
+}
 
 // Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@@ -39,6 +75,8 @@ type Client struct {
 	serverAddress string
 	hashedID      []byte
 
+	bufPool *sync.Pool
+
 	relayConn        net.Conn
 	conns            map[string]*connContainer
 	serviceIsRunning bool
@@ -61,7 +99,13 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
 		ctxCancel:     func() {},
 		serverAddress: serverAddress,
 		hashedID:      hashedID,
-		conns:         make(map[string]*connContainer),
+		bufPool: &sync.Pool{
+			New: func() any {
+				buf := make([]byte, bufferSize)
+				return &buf
+			},
+		},
+		conns: make(map[string]*connContainer),
 	}
 }
@@ -109,13 +153,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
 	hashedID, hashedStringID := messages.HashID(dstPeerID)
 	log.Infof("open connection to peer: %s", hashedStringID)
 
-	messageBuffer := make(chan Msg, 2)
-	conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer))
+	msgChannel := make(chan Msg, 2)
+	conn := NewConn(c, hashedID, hashedStringID, msgChannel)
 
-	c.conns[hashedStringID] = &connContainer{
-		conn,
-		messageBuffer,
-	}
+	c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
 	return conn, nil
 }
@@ -246,7 +287,8 @@ func (c *Client) readLoop(relayConn net.Conn) {
 		closedByServer bool
 	)
 	for {
-		buf := make([]byte, bufferSize)
+		bufPtr := c.bufPool.Get().(*[]byte)
+		buf := *bufPtr
 		n, errExit = relayConn.Read(buf)
 		if errExit != nil {
 			c.mu.Lock()
@@ -265,7 +307,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
 		switch msgType {
 		case messages.MsgTypeTransport:
-			peerID, err := messages.UnmarshalTransportID(buf[:n])
+			peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
 			if err != nil {
 				c.log.Errorf("failed to parse transport message: %v", err)
 				continue
@@ -284,8 +326,10 @@ func (c *Client) readLoop(relayConn net.Conn) {
 				continue
 			}
 
-			// todo review is this can cause panic
-			container.messages <- Msg{buf[:n]}
+			container.writeMsg(Msg{
+				bufPool: c.bufPool,
+				bufPtr:  bufPtr,
+				Payload: payload})
 		case messages.MsgClose:
 			closedByServer = true
 			log.Debugf("relay connection close by server")
@@ -321,26 +365,9 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
 	return n, err
 }
 
-func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) {
-	return func(b []byte) (n int, err error) {
-		msg, ok := <-msgChannel
-		if !ok {
-			return 0, io.EOF
-		}
-
-		payload, err := messages.UnmarshalTransportPayload(msg.buf)
-		if err != nil {
-			return 0, err
-		}
-		n = copy(b, payload)
-		return n, nil
-	}
-}
-
 func (c *Client) closeAllConns() {
 	for _, container := range c.conns {
-		close(container.messages)
+		container.close()
 	}
 	c.conns = make(map[string]*connContainer)
 }
@@ -350,11 +377,11 @@ func (c *Client) closeConn(id string) error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	conn, ok := c.conns[id]
+	container, ok := c.conns[id]
 	if !ok {
 		return fmt.Errorf("connection already closed")
 	}
-	close(conn.messages)
+	container.close()
 	delete(c.conns, id)
 	return nil
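
Editor's note: the write-to-closed-channel fix above reduces to one pattern: guard both the channel send and the `close` with the same mutex plus a `closed` flag, so a producer that races with teardown drops its message instead of panicking. A standalone sketch of that pattern (simplified types, not the commit's code):

```go
package main

import (
	"fmt"
	"sync"
)

// guardedChan serializes sends and close so a producer that races with
// teardown can never panic by sending on an already-closed channel.
type guardedChan struct {
	mu     sync.Mutex
	ch     chan string
	closed bool
}

func (g *guardedChan) send(v string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.closed {
		return // drop the message instead of panicking
	}
	g.ch <- v // buffered in practice, as in the commit's writeMsg
}

func (g *guardedChan) close() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.closed {
		return // idempotent: double close is a no-op
	}
	close(g.ch)
	g.closed = true
}

func main() {
	g := &guardedChan{ch: make(chan string, 1)}
	g.send("hello")
	g.close()
	g.send("late message") // safe no-op after close
	fmt.Println(<-g.ch)
}
```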

View File

@@ -1,6 +1,7 @@
 package client
 
 import (
+	"io"
 	"net"
 	"time"
 )
@@ -9,15 +10,15 @@ type Conn struct {
 	client      *Client
 	dstID       []byte
 	dstStringID string
-	readerFn    func(b []byte) (n int, err error)
+	messageChan chan Msg
 }
 
-func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn {
+func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
 	c := &Conn{
 		client:      client,
 		dstID:       dstID,
 		dstStringID: dstStringID,
-		readerFn:    readerFn,
+		messageChan: messageChan,
 	}
 
 	return c
@@ -28,7 +29,14 @@ func (c *Conn) Write(p []byte) (n int, err error) {
 }
 
 func (c *Conn) Read(b []byte) (n int, err error) {
-	return c.readerFn(b)
+	msg, ok := <-c.messageChan
+	if !ok {
+		return 0, io.EOF
+	}
+
+	n = copy(b, msg.Payload)
+	msg.Free()
+	return n, nil
 }
 
 func (c *Conn) Close() error {
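
Editor's note: with the generated reader function gone, `Conn.Read` drains the message channel, copies the payload out, frees the pooled buffer, and returns `io.EOF` once the channel is closed, so callers keep ordinary `net.Conn` semantics. A hypothetical consumer loop, using `net.Pipe` as a stand-in for a relay connection:

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"net"
)

func main() {
	// net.Pipe stands in for a relay Conn; the loop below relies only on
	// the net.Conn contract that the channel-backed Read preserves.
	client, server := net.Pipe()

	go func() {
		server.Write([]byte("first"))
		server.Write([]byte("second"))
		server.Close() // a closed message channel surfaces as io.EOF
	}()

	buf := make([]byte, 1500)
	for {
		n, err := client.Read(buf)
		if errors.Is(err, io.EOF) {
			fmt.Println("connection closed")
			return
		}
		if err != nil {
			fmt.Println("read error:", err)
			return
		}
		fmt.Printf("got %d bytes: %s\n", n, buf[:n])
	}
}
```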

View File

@@ -110,12 +110,13 @@ func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
 	return msg
 }
 
-func UnmarshalTransportPayload(buf []byte) ([]byte, error) {
+func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
 	headerSize := 1 + IDSize
 	if len(buf) < headerSize {
-		return nil, ErrInvalidMessageLength
+		return nil, nil, ErrInvalidMessageLength
 	}
 
-	return buf[headerSize:], nil
+	return buf[1:headerSize], buf[headerSize:], nil
 }
 
 func UnmarshalTransportID(buf []byte) ([]byte, error) {
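
Editor's note: from the new `UnmarshalTransportMsg` signature, a transport message is one type byte, an `IDSize`-byte peer ID, then the payload. A hedged round-trip sketch of that layout with local stand-ins (the real type constant, `IDSize`, and error value live in the `messages` package; the ID width below is an assumption):

```go
package main

import (
	"errors"
	"fmt"
)

// Local stand-ins: the real msgTypeTransport value, IDSize, and error
// live in the messages package; the ID width here is an assumption.
const (
	msgTypeTransport byte = 0
	idSize                = 8
)

var errInvalidMessageLength = errors.New("invalid message length")

// marshalTransportMsg mirrors the wire layout implied by the diff:
// [1 type byte][idSize-byte peer ID][payload...].
func marshalTransportMsg(peerID, payload []byte) []byte {
	msg := make([]byte, 0, 1+idSize+len(payload))
	msg = append(msg, msgTypeTransport)
	msg = append(msg, peerID...)
	return append(msg, payload...)
}

// unmarshalTransportMsg returns the peer ID and payload in one pass, which
// is what lets readLoop drop the separate UnmarshalTransportID call.
func unmarshalTransportMsg(buf []byte) (id, payload []byte, err error) {
	headerSize := 1 + idSize
	if len(buf) < headerSize {
		return nil, nil, errInvalidMessageLength
	}
	return buf[1:headerSize], buf[headerSize:], nil
}

func main() {
	peerID := []byte("peer0001") // exactly idSize bytes in this sketch
	msg := marshalTransportMsg(peerID, []byte("hello"))

	id, payload, err := unmarshalTransportMsg(msg)
	if err != nil {
		panic(err)
	}
	fmt.Printf("peer %s sent %q\n", id, payload)
}
```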