mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-01 12:33:53 +01:00
97e10e440c
Fix leaked server connections: close unused connections in the client lib and close the deprecated connection in the server lib. The Server Picker is reusable in the guard if we want that in the future, so we can support server address changes. --------- Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Add logging --------- Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
562 lines
13 KiB
Go
562 lines
13 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
|
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
|
"github.com/netbirdio/netbird/relay/messages"
|
|
)
|
|
|
|
// bufferSize is the length, in bytes, of every buffer created by the client's read-buffer pool.
const bufferSize = 8820

// serverResponseTimeout bounds how long the client waits for the relay server's handshake reply.
const serverResponseTimeout = 8 * time.Second
|
|
|
|
// ErrConnAlreadyExists is returned by OpenConn when a connection to the requested peer is already registered.
var ErrConnAlreadyExists = fmt.Errorf("connection already exists")
|
|
|
|
// internalStopFlag is a mutex-guarded boolean. The read loop uses it to tell a shutdown that the client
// triggered itself (health check timeout) apart from an unexpected read failure.
type internalStopFlag struct {
	sync.Mutex
	stop bool
}

// newInternalStopFlag returns a flag in the "not set" state.
func newInternalStopFlag() *internalStopFlag {
	return new(internalStopFlag)
}

// set marks the flag. There is no way to clear it again.
func (isf *internalStopFlag) set() {
	isf.Lock()
	isf.stop = true
	isf.Unlock()
}

// isSet reports whether set has been called.
func (isf *internalStopFlag) isSet() bool {
	isf.Lock()
	stopped := isf.stop
	isf.Unlock()
	return stopped
}
|
|
|
|
// Msg carries a payload from the server to the client. It also remembers the pooled buffer backing the
// payload, so the consumer (the net.Conn implementation) can release that buffer once it is done with it.
type Msg struct {
	// Payload is the transport payload.
	Payload []byte

	// bufPool is the pool that bufPtr is handed back to by Free.
	bufPool *sync.Pool
	// bufPtr is the pooled buffer handed back by Free.
	bufPtr *[]byte
}

// Free puts the message's buffer back into its pool. The receiver must carry a non-nil pool.
func (m *Msg) Free() {
	pool, buf := m.bufPool, m.bufPtr
	pool.Put(buf)
}
|
|
|
|
type connContainer struct {
|
|
conn *Conn
|
|
messages chan Msg
|
|
msgChanLock sync.Mutex
|
|
closed bool // flag to check if channel is closed
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
return &connContainer{
|
|
conn: conn,
|
|
messages: messages,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
func (cc *connContainer) writeMsg(msg Msg) {
|
|
cc.msgChanLock.Lock()
|
|
defer cc.msgChanLock.Unlock()
|
|
|
|
if cc.closed {
|
|
msg.Free()
|
|
return
|
|
}
|
|
|
|
select {
|
|
case cc.messages <- msg:
|
|
case <-cc.ctx.Done():
|
|
msg.Free()
|
|
}
|
|
}
|
|
|
|
func (cc *connContainer) close() {
|
|
cc.cancel()
|
|
|
|
cc.msgChanLock.Lock()
|
|
defer cc.msgChanLock.Unlock()
|
|
|
|
if cc.closed {
|
|
return
|
|
}
|
|
|
|
cc.closed = true
|
|
close(cc.messages)
|
|
|
|
for msg := range cc.messages {
|
|
msg.Free()
|
|
}
|
|
}
|
|
|
|
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
|
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
|
|
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
|
|
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
|
type Client struct {
|
|
log *log.Entry
|
|
parentCtx context.Context
|
|
connectionURL string
|
|
authTokenStore *auth.TokenStore
|
|
hashedID []byte
|
|
|
|
bufPool *sync.Pool
|
|
|
|
relayConn net.Conn
|
|
conns map[string]*connContainer
|
|
serviceIsRunning bool
|
|
mu sync.Mutex // protect serviceIsRunning and conns
|
|
readLoopMutex sync.Mutex
|
|
wgReadLoop sync.WaitGroup
|
|
instanceURL *RelayAddr
|
|
muInstanceURL sync.Mutex
|
|
|
|
onDisconnectListener func()
|
|
listenerMutex sync.Mutex
|
|
}
|
|
|
|
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
|
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
|
hashedID, hashedStringId := messages.HashID(peerID)
|
|
return &Client{
|
|
log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}),
|
|
parentCtx: ctx,
|
|
connectionURL: serverURL,
|
|
authTokenStore: authTokenStore,
|
|
hashedID: hashedID,
|
|
bufPool: &sync.Pool{
|
|
New: func() any {
|
|
buf := make([]byte, bufferSize)
|
|
return &buf
|
|
},
|
|
},
|
|
conns: make(map[string]*connContainer),
|
|
}
|
|
}
|
|
|
|
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
|
func (c *Client) Connect() error {
|
|
c.log.Infof("connecting to relay server")
|
|
c.readLoopMutex.Lock()
|
|
defer c.readLoopMutex.Unlock()
|
|
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.serviceIsRunning {
|
|
return nil
|
|
}
|
|
|
|
err := c.connect()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.serviceIsRunning = true
|
|
|
|
c.wgReadLoop.Add(1)
|
|
go c.readLoop(c.relayConn)
|
|
|
|
c.log.Infof("relay connection established")
|
|
return nil
|
|
}
|
|
|
|
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
|
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
|
// it will return immediately.
|
|
// todo: what should happen if call with the same peerID with multiple times?
|
|
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if !c.serviceIsRunning {
|
|
return nil, fmt.Errorf("relay connection is not established")
|
|
}
|
|
|
|
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
|
_, ok := c.conns[hashedStringID]
|
|
if ok {
|
|
return nil, ErrConnAlreadyExists
|
|
}
|
|
|
|
c.log.Infof("open connection to peer: %s", hashedStringID)
|
|
msgChannel := make(chan Msg, 2)
|
|
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
|
|
|
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
|
return conn, nil
|
|
}
|
|
|
|
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
|
|
func (c *Client) ServerInstanceURL() (string, error) {
|
|
c.muInstanceURL.Lock()
|
|
defer c.muInstanceURL.Unlock()
|
|
if c.instanceURL == nil {
|
|
return "", fmt.Errorf("relay connection is not established")
|
|
}
|
|
return c.instanceURL.String(), nil
|
|
}
|
|
|
|
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
|
func (c *Client) SetOnDisconnectListener(fn func()) {
|
|
c.listenerMutex.Lock()
|
|
defer c.listenerMutex.Unlock()
|
|
c.onDisconnectListener = fn
|
|
}
|
|
|
|
// HasConns returns true if there are connections.
|
|
func (c *Client) HasConns() bool {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return len(c.conns) > 0
|
|
}
|
|
|
|
// Close closes the connection to the relay server and all connections to other peers.
|
|
func (c *Client) Close() error {
|
|
return c.close(true)
|
|
}
|
|
|
|
func (c *Client) connect() error {
|
|
conn, err := ws.Dial(c.connectionURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.relayConn = conn
|
|
|
|
err = c.handShake()
|
|
if err != nil {
|
|
cErr := conn.Close()
|
|
if cErr != nil {
|
|
c.log.Errorf("failed to close connection: %s", cErr)
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) handShake() error {
|
|
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
|
if err != nil {
|
|
c.log.Errorf("failed to marshal auth message: %s", err)
|
|
return err
|
|
}
|
|
|
|
_, err = c.relayConn.Write(msg)
|
|
if err != nil {
|
|
c.log.Errorf("failed to send auth message: %s", err)
|
|
return err
|
|
}
|
|
buf := make([]byte, messages.MaxHandshakeRespSize)
|
|
n, err := c.readWithTimeout(buf)
|
|
if err != nil {
|
|
c.log.Errorf("failed to read auth response: %s", err)
|
|
return err
|
|
}
|
|
|
|
_, err = messages.ValidateVersion(buf[:n])
|
|
if err != nil {
|
|
return fmt.Errorf("validate version: %w", err)
|
|
}
|
|
|
|
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
|
if err != nil {
|
|
c.log.Errorf("failed to determine message type: %s", err)
|
|
return err
|
|
}
|
|
|
|
if msgType != messages.MsgTypeAuthResponse {
|
|
c.log.Errorf("unexpected message type: %s", msgType)
|
|
return fmt.Errorf("unexpected message type")
|
|
}
|
|
|
|
addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
c.muInstanceURL.Lock()
|
|
c.instanceURL = &RelayAddr{addr: addr}
|
|
c.muInstanceURL.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) readLoop(relayConn net.Conn) {
|
|
internallyStoppedFlag := newInternalStopFlag()
|
|
hc := healthcheck.NewReceiver()
|
|
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
|
|
|
var (
|
|
errExit error
|
|
n int
|
|
)
|
|
for {
|
|
bufPtr := c.bufPool.Get().(*[]byte)
|
|
buf := *bufPtr
|
|
n, errExit = relayConn.Read(buf)
|
|
if errExit != nil {
|
|
c.log.Infof("start to Relay read loop exit")
|
|
c.mu.Lock()
|
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
|
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
|
}
|
|
c.mu.Unlock()
|
|
break
|
|
}
|
|
|
|
_, err := messages.ValidateVersion(buf[:n])
|
|
if err != nil {
|
|
c.log.Errorf("failed to validate protocol version: %s", err)
|
|
c.bufPool.Put(bufPtr)
|
|
continue
|
|
}
|
|
|
|
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
|
|
if err != nil {
|
|
c.log.Errorf("failed to determine message type: %s", err)
|
|
c.bufPool.Put(bufPtr)
|
|
continue
|
|
}
|
|
|
|
if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
|
|
break
|
|
}
|
|
}
|
|
|
|
hc.Stop()
|
|
|
|
c.muInstanceURL.Lock()
|
|
c.instanceURL = nil
|
|
c.muInstanceURL.Unlock()
|
|
|
|
c.notifyDisconnected()
|
|
c.wgReadLoop.Done()
|
|
_ = c.close(false)
|
|
}
|
|
|
|
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
|
|
switch msgType {
|
|
case messages.MsgTypeHealthCheck:
|
|
c.handleHealthCheck(hc, internallyStoppedFlag)
|
|
c.bufPool.Put(bufPtr)
|
|
case messages.MsgTypeTransport:
|
|
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
|
case messages.MsgTypeClose:
|
|
c.log.Debugf("relay connection close by server")
|
|
c.bufPool.Put(bufPtr)
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
|
|
msg := messages.MarshalHealthcheck()
|
|
_, wErr := c.relayConn.Write(msg)
|
|
if wErr != nil {
|
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
|
c.log.Errorf("failed to send heartbeat: %s", wErr)
|
|
}
|
|
}
|
|
hc.Heartbeat()
|
|
}
|
|
|
|
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
|
|
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
|
|
if err != nil {
|
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
|
c.log.Errorf("failed to parse transport message: %v", err)
|
|
}
|
|
|
|
c.bufPool.Put(bufPtr)
|
|
return true
|
|
}
|
|
|
|
stringID := messages.HashIDToString(peerID)
|
|
|
|
c.mu.Lock()
|
|
if !c.serviceIsRunning {
|
|
c.mu.Unlock()
|
|
c.bufPool.Put(bufPtr)
|
|
return false
|
|
}
|
|
container, ok := c.conns[stringID]
|
|
c.mu.Unlock()
|
|
if !ok {
|
|
c.log.Errorf("peer not found: %s", stringID)
|
|
c.bufPool.Put(bufPtr)
|
|
return true
|
|
}
|
|
msg := Msg{
|
|
bufPool: c.bufPool,
|
|
bufPtr: bufPtr,
|
|
Payload: payload,
|
|
}
|
|
container.writeMsg(msg)
|
|
return true
|
|
}
|
|
|
|
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
|
|
c.mu.Lock()
|
|
conn, ok := c.conns[id]
|
|
c.mu.Unlock()
|
|
if !ok {
|
|
return 0, io.EOF
|
|
}
|
|
|
|
if conn.conn != connReference {
|
|
return 0, io.EOF
|
|
}
|
|
|
|
// todo: use buffer pool instead of create new transport msg.
|
|
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
|
if err != nil {
|
|
c.log.Errorf("failed to marshal transport message: %s", err)
|
|
return 0, err
|
|
}
|
|
|
|
// the write always return with 0 length because the underling does not support the size feedback.
|
|
_, err = c.relayConn.Write(msg)
|
|
if err != nil {
|
|
c.log.Errorf("failed to write transport message: %s", err)
|
|
}
|
|
return len(payload), err
|
|
}
|
|
|
|
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
|
for {
|
|
select {
|
|
case _, ok := <-hc.OnTimeout:
|
|
if !ok {
|
|
return
|
|
}
|
|
c.log.Errorf("health check timeout")
|
|
internalStopFlag.set()
|
|
_ = conn.Close() // ignore the err because the readLoop will handle it
|
|
return
|
|
case <-c.parentCtx.Done():
|
|
err := c.close(true)
|
|
if err != nil {
|
|
c.log.Errorf("failed to teardown connection: %s", err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) closeAllConns() {
|
|
for _, container := range c.conns {
|
|
container.close()
|
|
}
|
|
c.conns = make(map[string]*connContainer)
|
|
}
|
|
|
|
func (c *Client) closeConn(connReference *Conn, id string) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
container, ok := c.conns[id]
|
|
if !ok {
|
|
return fmt.Errorf("connection already closed")
|
|
}
|
|
|
|
if container.conn != connReference {
|
|
return fmt.Errorf("conn reference mismatch")
|
|
}
|
|
delete(c.conns, id)
|
|
container.close()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) close(gracefullyExit bool) error {
|
|
c.readLoopMutex.Lock()
|
|
defer c.readLoopMutex.Unlock()
|
|
|
|
c.mu.Lock()
|
|
var err error
|
|
if !c.serviceIsRunning {
|
|
c.mu.Unlock()
|
|
c.log.Warn("relay connection was already marked as not running")
|
|
return nil
|
|
}
|
|
|
|
c.serviceIsRunning = false
|
|
c.log.Infof("closing all peer connections")
|
|
c.closeAllConns()
|
|
if gracefullyExit {
|
|
c.writeCloseMsg()
|
|
}
|
|
err = c.relayConn.Close()
|
|
c.mu.Unlock()
|
|
|
|
c.log.Infof("waiting for read loop to close")
|
|
c.wgReadLoop.Wait()
|
|
c.log.Infof("relay connection closed")
|
|
return err
|
|
}
|
|
|
|
func (c *Client) notifyDisconnected() {
|
|
c.listenerMutex.Lock()
|
|
defer c.listenerMutex.Unlock()
|
|
|
|
if c.onDisconnectListener == nil {
|
|
return
|
|
}
|
|
go c.onDisconnectListener()
|
|
}
|
|
|
|
func (c *Client) writeCloseMsg() {
|
|
msg := messages.MarshalCloseMsg()
|
|
_, err := c.relayConn.Write(msg)
|
|
if err != nil {
|
|
c.log.Errorf("failed to send close message: %s", err)
|
|
}
|
|
}
|
|
|
|
func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
|
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
|
|
defer cancel()
|
|
|
|
readDone := make(chan struct{})
|
|
var (
|
|
n int
|
|
err error
|
|
)
|
|
|
|
go func() {
|
|
n, err = c.relayConn.Read(buf)
|
|
close(readDone)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return 0, fmt.Errorf("read operation timed out")
|
|
case <-readDone:
|
|
return n, err
|
|
}
|
|
}
|