mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 18:41:41 +02:00
Add client side heartbeat handling
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
"github.com/netbirdio/netbird/relay/client/dialer/ws"
|
||||||
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,6 +24,27 @@ var (
|
|||||||
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
|
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// internalStopFlag is a boolean guarded by an embedded mutex. readLoop creates
// one per run and passes it to listenForStopEvents, which sets it before
// closing the connection on a health check timeout; readLoop consults it to
// decide whether read/write/parse failures should still be logged.
type internalStopFlag struct {
|
||||||
|
sync.Mutex
|
||||||
|
stop bool // false until set() is called; never reset afterwards
|
||||||
|
}
|
||||||
|
|
||||||
|
// newInternalStopFlag returns a pointer to a zero-valued internalStopFlag,
// i.e. one whose stop field is false.
func newInternalStopFlag() *internalStopFlag {
|
||||||
|
return &internalStopFlag{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set marks the flag as stopped. The write happens while holding the embedded
// mutex, which is released when the method returns.
func (isf *internalStopFlag) set() {
|
||||||
|
isf.Lock()
|
||||||
|
defer isf.Unlock()
|
||||||
|
isf.stop = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSet reports whether set has been called. The read happens while holding
// the embedded mutex.
func (isf *internalStopFlag) isSet() bool {
|
||||||
|
isf.Lock()
|
||||||
|
defer isf.Unlock()
|
||||||
|
return isf.stop
|
||||||
|
}
|
||||||
|
|
||||||
// Msg carries the payload from the server to the client. With this struct, the net.Conn can free the buffer.
|
// Msg carries the payload from the server to the client. With this struct, the net.Conn can free the buffer.
|
||||||
type Msg struct {
|
type Msg struct {
|
||||||
Payload []byte
|
Payload []byte
|
||||||
@@ -75,7 +97,6 @@ func (cc *connContainer) close() {
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
parentCtx context.Context
|
parentCtx context.Context
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
serverAddress string
|
serverAddress string
|
||||||
hashedID []byte
|
hashedID []byte
|
||||||
|
|
||||||
@@ -84,7 +105,7 @@ type Client struct {
|
|||||||
relayConn net.Conn
|
relayConn net.Conn
|
||||||
conns map[string]*connContainer
|
conns map[string]*connContainer
|
||||||
serviceIsRunning bool
|
serviceIsRunning bool
|
||||||
mu sync.Mutex
|
mu sync.Mutex // protect serviceIsRunning and conns
|
||||||
readLoopMutex sync.Mutex
|
readLoopMutex sync.Mutex
|
||||||
wgReadLoop sync.WaitGroup
|
wgReadLoop sync.WaitGroup
|
||||||
|
|
||||||
@@ -100,7 +121,6 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
|||||||
return &Client{
|
return &Client{
|
||||||
log: log.WithField("client_id", hashedStringId),
|
log: log.WithField("client_id", hashedStringId),
|
||||||
parentCtx: ctx,
|
parentCtx: ctx,
|
||||||
ctxCancel: func() {},
|
|
||||||
serverAddress: serverAddress,
|
serverAddress: serverAddress,
|
||||||
hashedID: hashedID,
|
hashedID: hashedID,
|
||||||
bufPool: &sync.Pool{
|
bufPool: &sync.Pool{
|
||||||
@@ -133,15 +153,6 @@ func (c *Client) Connect() error {
|
|||||||
|
|
||||||
c.serviceIsRunning = true
|
c.serviceIsRunning = true
|
||||||
|
|
||||||
var ctx context.Context
|
|
||||||
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
|
|
||||||
context.AfterFunc(ctx, func() {
|
|
||||||
cErr := c.close(false)
|
|
||||||
if cErr != nil {
|
|
||||||
log.Errorf("failed to close relay connection: %s", cErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
c.wgReadLoop.Add(1)
|
c.wgReadLoop.Add(1)
|
||||||
go c.readLoop(c.relayConn)
|
go c.readLoop(c.relayConn)
|
||||||
|
|
||||||
@@ -200,7 +211,7 @@ func (c *Client) HasConns() bool {
|
|||||||
|
|
||||||
// Close closes the connection to the relay server and all connections to other peers.
|
// Close closes the connection to the relay server and all connections to other peers.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
return c.close(false)
|
return c.close(true) // true is passed as close's gracefullyExit argument
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) connect() error {
|
func (c *Client) connect() error {
|
||||||
@@ -257,10 +268,13 @@ func (c *Client) handShake() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readLoop(relayConn net.Conn) {
|
func (c *Client) readLoop(relayConn net.Conn) {
|
||||||
|
internallyStoppedFlag := newInternalStopFlag()
|
||||||
|
hc := healthcheck.NewReceiver()
|
||||||
|
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errExit error
|
errExit error
|
||||||
n int
|
n int
|
||||||
closedByServer bool
|
|
||||||
)
|
)
|
||||||
for {
|
for {
|
||||||
bufPtr := c.bufPool.Get().(*[]byte)
|
bufPtr := c.bufPool.Get().(*[]byte)
|
||||||
@@ -268,7 +282,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
n, errExit = relayConn.Read(buf)
|
n, errExit = relayConn.Read(buf)
|
||||||
if errExit != nil {
|
if errExit != nil {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if c.serviceIsRunning {
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
@@ -283,15 +297,19 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
|
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case messages.MsgTypeHealthCheck:
|
case messages.MsgTypeHealthCheck:
|
||||||
|
log.Debugf("on new heartbeat")
|
||||||
msg := messages.MarshalHealthcheck()
|
msg := messages.MarshalHealthcheck()
|
||||||
_, err := c.relayConn.Write(msg)
|
_, wErr := c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||||
c.log.Errorf("failed to send heartbeat response: %s", err)
|
c.log.Errorf("failed to send heartbeat: %s", wErr)
|
||||||
}
|
}
|
||||||
|
hc.Heartbeat()
|
||||||
case messages.MsgTypeTransport:
|
case messages.MsgTypeTransport:
|
||||||
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
|
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to parse transport message: %v", err)
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||||
|
c.log.Errorf("failed to parse transport message: %v", err)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
stringID := messages.HashIDToString(peerID)
|
stringID := messages.HashIDToString(peerID)
|
||||||
@@ -313,16 +331,16 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
bufPtr: bufPtr,
|
bufPtr: bufPtr,
|
||||||
Payload: payload})
|
Payload: payload})
|
||||||
case messages.MsgTypeClose:
|
case messages.MsgTypeClose:
|
||||||
closedByServer = true
|
|
||||||
log.Debugf("relay connection close by server")
|
log.Debugf("relay connection close by server")
|
||||||
goto Exit
|
goto Exit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Exit:
|
Exit:
|
||||||
|
hc.Stop()
|
||||||
c.notifyDisconnected()
|
c.notifyDisconnected()
|
||||||
c.wgReadLoop.Done()
|
c.wgReadLoop.Done()
|
||||||
_ = c.close(closedByServer)
|
_ = c.close(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo check by reference too, the id is not enough because the id comes from the outer conn
|
// todo check by reference too, the id is not enough because the id comes from the outer conn
|
||||||
@@ -352,6 +370,27 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// listenForStopEvents blocks until one of two events occurs and then returns:
//   - a receive on hc.OnTimeout: if the channel delivered a value, the timeout
//     is logged, internalStopFlag is set, and conn is closed; if the channel was
//     closed (ok == false) the function returns without touching conn.
//   - c.parentCtx is done: c.close(true) is called and any error is logged.
//
// Every branch of the select returns, so the enclosing for loop executes its
// body at most once.
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _, ok := <-hc.OnTimeout:
|
||||||
|
if !ok { // OnTimeout was closed rather than signalled; NOTE(review): presumably by hc.Stop() — confirm
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.log.Errorf("health check timeout")
|
||||||
|
internalStopFlag.set() // set before closing conn so readLoop can tell this close was internal
|
||||||
|
_ = conn.Close() // ignore the err because the readLoop will handle it
|
||||||
|
return
|
||||||
|
case <-c.parentCtx.Done():
|
||||||
|
err := c.close(true)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to teardown connection: %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) closeAllConns() {
|
func (c *Client) closeAllConns() {
|
||||||
for _, container := range c.conns {
|
for _, container := range c.conns {
|
||||||
container.close()
|
container.close()
|
||||||
@@ -374,7 +413,7 @@ func (c *Client) closeConn(id string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) close(byServer bool) error {
|
func (c *Client) close(gracefullyExit bool) error {
|
||||||
c.readLoopMutex.Lock()
|
c.readLoopMutex.Lock()
|
||||||
defer c.readLoopMutex.Unlock()
|
defer c.readLoopMutex.Unlock()
|
||||||
|
|
||||||
@@ -387,7 +426,7 @@ func (c *Client) close(byServer bool) error {
|
|||||||
|
|
||||||
c.serviceIsRunning = false
|
c.serviceIsRunning = false
|
||||||
c.closeAllConns()
|
c.closeAllConns()
|
||||||
if !byServer {
|
if gracefullyExit {
|
||||||
c.writeCloseMsg()
|
c.writeCloseMsg()
|
||||||
err = c.relayConn.Close()
|
err = c.relayConn.Close()
|
||||||
}
|
}
|
||||||
@@ -395,7 +434,6 @@ func (c *Client) close(byServer bool) error {
|
|||||||
|
|
||||||
c.wgReadLoop.Wait()
|
c.wgReadLoop.Wait()
|
||||||
c.log.Infof("relay connection closed with: %s", c.serverAddress)
|
c.log.Infof("relay connection closed with: %s", c.serverAddress)
|
||||||
c.ctxCancel()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -20,7 +20,7 @@ type Receiver struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
heartbeat chan struct{}
|
heartbeat chan struct{}
|
||||||
live bool
|
alive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReceiver creates a new healthcheck receiver and starts the timer in the background
|
// NewReceiver creates a new healthcheck receiver and starts the timer in the background
|
||||||
@@ -60,19 +60,24 @@ func (r *Receiver) waitForHealthcheck() {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-r.heartbeat:
|
case <-r.heartbeat:
|
||||||
r.live = true
|
r.alive = true
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if r.live {
|
if r.alive {
|
||||||
r.live = false
|
r.alive = false
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
select {
|
|
||||||
case r.OnTimeout <- struct{}{}:
|
r.notifyTimeout()
|
||||||
default:
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
case <-r.ctx.Done():
|
case <-r.ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// notifyTimeout attempts a non-blocking send of an empty struct on r.OnTimeout.
// If the send cannot proceed immediately, the default branch is taken and the
// notification is dropped.
func (r *Receiver) notifyTimeout() {
|
||||||
|
select {
|
||||||
|
case r.OnTimeout <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user