mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-23 06:18:46 +01:00
Use buf pool
- eliminate reader function generation - fix write to closed channel panic
This commit is contained in:
parent
8c70b7d7ff
commit
5e93d117cf
@ -3,7 +3,6 @@ package client
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@ -11,6 +10,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,13 +19,49 @@ const (
|
|||||||
serverResponseTimeout = 8 * time.Second
|
serverResponseTimeout = 8 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer.
|
||||||
type Msg struct {
|
type Msg struct {
|
||||||
buf []byte
|
Payload []byte
|
||||||
|
|
||||||
|
bufPool *sync.Pool
|
||||||
|
bufPtr *[]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Msg) Free() {
|
||||||
|
m.bufPool.Put(m.bufPtr)
|
||||||
}
|
}
|
||||||
|
|
||||||
type connContainer struct {
|
type connContainer struct {
|
||||||
conn *Conn
|
conn *Conn
|
||||||
messages chan Msg
|
messages chan Msg
|
||||||
|
msgChanLock sync.Mutex
|
||||||
|
closed bool // flag to check if channel is closed
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
||||||
|
return &connContainer{
|
||||||
|
conn: conn,
|
||||||
|
messages: messages,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *connContainer) writeMsg(msg Msg) {
|
||||||
|
cc.msgChanLock.Lock()
|
||||||
|
defer cc.msgChanLock.Unlock()
|
||||||
|
if cc.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cc.messages <- msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *connContainer) close() {
|
||||||
|
cc.msgChanLock.Lock()
|
||||||
|
defer cc.msgChanLock.Unlock()
|
||||||
|
if cc.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
close(cc.messages)
|
||||||
|
cc.closed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||||
@ -39,6 +75,8 @@ type Client struct {
|
|||||||
serverAddress string
|
serverAddress string
|
||||||
hashedID []byte
|
hashedID []byte
|
||||||
|
|
||||||
|
bufPool *sync.Pool
|
||||||
|
|
||||||
relayConn net.Conn
|
relayConn net.Conn
|
||||||
conns map[string]*connContainer
|
conns map[string]*connContainer
|
||||||
serviceIsRunning bool
|
serviceIsRunning bool
|
||||||
@ -61,7 +99,13 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
|||||||
ctxCancel: func() {},
|
ctxCancel: func() {},
|
||||||
serverAddress: serverAddress,
|
serverAddress: serverAddress,
|
||||||
hashedID: hashedID,
|
hashedID: hashedID,
|
||||||
conns: make(map[string]*connContainer),
|
bufPool: &sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
buf := make([]byte, bufferSize)
|
||||||
|
return &buf
|
||||||
|
},
|
||||||
|
},
|
||||||
|
conns: make(map[string]*connContainer),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,13 +153,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
|||||||
|
|
||||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||||
log.Infof("open connection to peer: %s", hashedStringID)
|
log.Infof("open connection to peer: %s", hashedStringID)
|
||||||
messageBuffer := make(chan Msg, 2)
|
msgChannel := make(chan Msg, 2)
|
||||||
conn := NewConn(c, hashedID, hashedStringID, c.generateConnReaderFN(messageBuffer))
|
conn := NewConn(c, hashedID, hashedStringID, msgChannel)
|
||||||
|
|
||||||
c.conns[hashedStringID] = &connContainer{
|
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
||||||
conn,
|
|
||||||
messageBuffer,
|
|
||||||
}
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -246,7 +287,8 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
closedByServer bool
|
closedByServer bool
|
||||||
)
|
)
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, bufferSize)
|
bufPtr := c.bufPool.Get().(*[]byte)
|
||||||
|
buf := *bufPtr
|
||||||
n, errExit = relayConn.Read(buf)
|
n, errExit = relayConn.Read(buf)
|
||||||
if errExit != nil {
|
if errExit != nil {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@ -265,7 +307,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
|
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case messages.MsgTypeTransport:
|
case messages.MsgTypeTransport:
|
||||||
peerID, err := messages.UnmarshalTransportID(buf[:n])
|
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to parse transport message: %v", err)
|
c.log.Errorf("failed to parse transport message: %v", err)
|
||||||
continue
|
continue
|
||||||
@ -284,8 +326,10 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo review is this can cause panic
|
container.writeMsg(Msg{
|
||||||
container.messages <- Msg{buf[:n]}
|
bufPool: c.bufPool,
|
||||||
|
bufPtr: bufPtr,
|
||||||
|
Payload: payload})
|
||||||
case messages.MsgClose:
|
case messages.MsgClose:
|
||||||
closedByServer = true
|
closedByServer = true
|
||||||
log.Debugf("relay connection close by server")
|
log.Debugf("relay connection close by server")
|
||||||
@ -321,26 +365,9 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) {
|
|
||||||
return func(b []byte) (n int, err error) {
|
|
||||||
msg, ok := <-msgChannel
|
|
||||||
if !ok {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
payload, err := messages.UnmarshalTransportPayload(msg.buf)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n = copy(b, payload)
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) closeAllConns() {
|
func (c *Client) closeAllConns() {
|
||||||
for _, container := range c.conns {
|
for _, container := range c.conns {
|
||||||
close(container.messages)
|
container.close()
|
||||||
}
|
}
|
||||||
c.conns = make(map[string]*connContainer)
|
c.conns = make(map[string]*connContainer)
|
||||||
}
|
}
|
||||||
@ -350,11 +377,11 @@ func (c *Client) closeConn(id string) error {
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
conn, ok := c.conns[id]
|
container, ok := c.conns[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("connection already closed")
|
return fmt.Errorf("connection already closed")
|
||||||
}
|
}
|
||||||
close(conn.messages)
|
container.close()
|
||||||
delete(c.conns, id)
|
delete(c.conns, id)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -9,15 +10,15 @@ type Conn struct {
|
|||||||
client *Client
|
client *Client
|
||||||
dstID []byte
|
dstID []byte
|
||||||
dstStringID string
|
dstStringID string
|
||||||
readerFn func(b []byte) (n int, err error)
|
messageChan chan Msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(client *Client, dstID []byte, dstStringID string, readerFn func(b []byte) (n int, err error)) *Conn {
|
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
client: client,
|
client: client,
|
||||||
dstID: dstID,
|
dstID: dstID,
|
||||||
dstStringID: dstStringID,
|
dstStringID: dstStringID,
|
||||||
readerFn: readerFn,
|
messageChan: messageChan,
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
@ -28,7 +29,14 @@ func (c *Conn) Write(p []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
return c.readerFn(b)
|
msg, ok := <-c.messageChan
|
||||||
|
if !ok {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
n = copy(b, msg.Payload)
|
||||||
|
msg.Free()
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
|
@ -110,12 +110,13 @@ func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalTransportPayload(buf []byte) ([]byte, error) {
|
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||||
headerSize := 1 + IDSize
|
headerSize := 1 + IDSize
|
||||||
if len(buf) < headerSize {
|
if len(buf) < headerSize {
|
||||||
return nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
return buf[headerSize:], nil
|
|
||||||
|
return buf[1:headerSize], buf[headerSize:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||||
|
Loading…
Reference in New Issue
Block a user