mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-07 06:29:06 +01:00
Use buf pool
- eliminate reader function generation - fix write to closed channel panic
This commit is contained in:
parent
8c70b7d7ff
commit
5e93d117cf
@ -3,7 +3,6 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
@ -11,6 +10,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
@ -19,13 +19,49 @@ const (
|
||||
serverResponseTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer.
|
||||
type Msg struct {
|
||||
buf []byte
|
||||
Payload []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
bufPtr *[]byte
|
||||
}
|
||||
|
||||
func (m *Msg) Free() {
|
||||
m.bufPool.Put(m.bufPtr)
|
||||
}
|
||||
|
||||
type connContainer struct {
|
||||
conn *Conn
|
||||
messages chan Msg
|
||||
conn *Conn
|
||||
messages chan Msg
|
||||
msgChanLock sync.Mutex
|
||||
closed bool // flag to check if channel is closed
|
||||
}
|
||||
|
||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
||||
return &connContainer{
|
||||
conn: conn,
|
||||
messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) writeMsg(msg Msg) {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
cc.messages <- msg
|
||||
}
|
||||
|
||||
func (cc *connContainer) close() {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
close(cc.messages)
|
||||
cc.closed = true
|
||||
}
|
||||
|
||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||
@ -39,6 +75,8 @@ type Client struct {
|
||||
serverAddress string
|
||||
hashedID []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[string]*connContainer
|
||||
serviceIsRunning bool
|
||||
@ -61,7 +99,13 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
||||
ctxCancel: func() {},
|
||||
serverAddress: serverAddress,
|
||||
hashedID: hashedID,
|
||||
conns: make(map[string]*connContainer),
|
||||
bufPool: &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, bufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
conns: make(map[string]*connContainer),
|
||||
}
|
||||
}
|
||||
|
||||
@ -109,13 +153,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
|
||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||
log.Infof("open connection to peer: %s", hashedStringID)
|
||||
messageBuffer := make(chan Msg, 2)
|
||||
conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer))
|
||||
msgChannel := make(chan Msg, 2)
|
||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel)
|
||||
|
||||
c.conns[hashedStringID] = &connContainer{
|
||||
conn,
|
||||
messageBuffer,
|
||||
}
|
||||
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@ -246,7 +287,8 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
closedByServer bool
|
||||
)
|
||||
for {
|
||||
buf := make([]byte, bufferSize)
|
||||
bufPtr := c.bufPool.Get().(*[]byte)
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.mu.Lock()
|
||||
@ -265,7 +307,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
|
||||
switch msgType {
|
||||
case messages.MsgTypeTransport:
|
||||
peerID, err := messages.UnmarshalTransportID(buf[:n])
|
||||
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to parse transport message: %v", err)
|
||||
continue
|
||||
@ -284,8 +326,10 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
continue
|
||||
}
|
||||
|
||||
// todo review is this can cause panic
|
||||
container.messages <- Msg{buf[:n]}
|
||||
container.writeMsg(Msg{
|
||||
bufPool: c.bufPool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: payload})
|
||||
case messages.MsgClose:
|
||||
closedByServer = true
|
||||
log.Debugf("relay connection close by server")
|
||||
@ -321,26 +365,9 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) {
|
||||
return func(b []byte) (n int, err error) {
|
||||
msg, ok := <-msgChannel
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
payload, err := messages.UnmarshalTransportPayload(msg.buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n = copy(b, payload)
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeAllConns() {
|
||||
for _, container := range c.conns {
|
||||
close(container.messages)
|
||||
container.close()
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
}
|
||||
@ -350,11 +377,11 @@ func (c *Client) closeConn(id string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
conn, ok := c.conns[id]
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
}
|
||||
close(conn.messages)
|
||||
container.close()
|
||||
delete(c.conns, id)
|
||||
|
||||
return nil
|
||||
|
@ -1,6 +1,7 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@ -9,15 +10,15 @@ type Conn struct {
|
||||
client *Client
|
||||
dstID []byte
|
||||
dstStringID string
|
||||
readerFn func(b []byte) (n int, err error)
|
||||
messageChan chan Msg
|
||||
}
|
||||
|
||||
func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn {
|
||||
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
|
||||
c := &Conn{
|
||||
client: client,
|
||||
dstID: dstID,
|
||||
dstStringID: dstStringID,
|
||||
readerFn: readerFn,
|
||||
messageChan: messageChan,
|
||||
}
|
||||
|
||||
return c
|
||||
@ -28,7 +29,14 @@ func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
return c.readerFn(b)
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
msg.Free()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
|
@ -110,12 +110,13 @@ func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
||||
return msg
|
||||
}
|
||||
|
||||
func UnmarshalTransportPayload(buf []byte) ([]byte, error) {
|
||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||
headerSize := 1 + IDSize
|
||||
if len(buf) < headerSize {
|
||||
return nil, ErrInvalidMessageLength
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
return buf[headerSize:], nil
|
||||
|
||||
return buf[1:headerSize], buf[headerSize:], nil
|
||||
}
|
||||
|
||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user