Add client side heartbeat handling

This commit is contained in:
Zoltán Papp
2024-06-29 14:13:05 +02:00
parent faeae52329
commit aa55fba5ee
2 changed files with 77 additions and 34 deletions

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/client/dialer/ws" "github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
) )
@@ -23,6 +24,27 @@ var (
ErrConnAlreadyExists = fmt.Errorf("connection already exists") ErrConnAlreadyExists = fmt.Errorf("connection already exists")
) )
// internalStopFlag is a mutex-guarded boolean that starts out false and can
// only be raised, never lowered. readLoop creates one per relay connection
// and consults it to tell a shutdown the client triggered itself apart from
// an unexpected read or write failure.
type internalStopFlag struct {
	sync.Mutex
	stop bool // guarded by the embedded mutex
}

// newInternalStopFlag returns a flag that is not yet set.
func newInternalStopFlag() *internalStopFlag {
	return new(internalStopFlag)
}

// set raises the flag. Calling it more than once has no further effect.
func (isf *internalStopFlag) set() {
	isf.Lock()
	isf.stop = true
	isf.Unlock()
}

// isSet reports whether set has been called on this flag.
func (isf *internalStopFlag) isSet() bool {
	isf.Lock()
	stopped := isf.stop
	isf.Unlock()
	return stopped
}
// Msg carries the payload from the server to the client. With this struct, the net.Conn can free the buffer. // Msg carries the payload from the server to the client. With this struct, the net.Conn can free the buffer.
type Msg struct { type Msg struct {
Payload []byte Payload []byte
@@ -75,7 +97,6 @@ func (cc *connContainer) close() {
type Client struct { type Client struct {
log *log.Entry log *log.Entry
parentCtx context.Context parentCtx context.Context
ctxCancel context.CancelFunc
serverAddress string serverAddress string
hashedID []byte hashedID []byte
@@ -84,7 +105,7 @@ type Client struct {
relayConn net.Conn relayConn net.Conn
conns map[string]*connContainer conns map[string]*connContainer
serviceIsRunning bool serviceIsRunning bool
mu sync.Mutex mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup wgReadLoop sync.WaitGroup
@@ -100,7 +121,6 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
return &Client{ return &Client{
log: log.WithField("client_id", hashedStringId), log: log.WithField("client_id", hashedStringId),
parentCtx: ctx, parentCtx: ctx,
ctxCancel: func() {},
serverAddress: serverAddress, serverAddress: serverAddress,
hashedID: hashedID, hashedID: hashedID,
bufPool: &sync.Pool{ bufPool: &sync.Pool{
@@ -133,15 +153,6 @@ func (c *Client) Connect() error {
c.serviceIsRunning = true c.serviceIsRunning = true
var ctx context.Context
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
context.AfterFunc(ctx, func() {
cErr := c.close(false)
if cErr != nil {
log.Errorf("failed to close relay connection: %s", cErr)
}
})
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(c.relayConn)
@@ -200,7 +211,7 @@ func (c *Client) HasConns() bool {
// Close closes the connection to the relay server and all connections to other peers. // Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error { func (c *Client) Close() error {
return c.close(false) return c.close(true)
} }
func (c *Client) connect() error { func (c *Client) connect() error {
@@ -257,10 +268,13 @@ func (c *Client) handShake() error {
} }
func (c *Client) readLoop(relayConn net.Conn) { func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver()
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var ( var (
errExit error errExit error
n int n int
closedByServer bool
) )
for { for {
bufPtr := c.bufPool.Get().(*[]byte) bufPtr := c.bufPool.Get().(*[]byte)
@@ -268,7 +282,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
n, errExit = relayConn.Read(buf) n, errExit = relayConn.Read(buf)
if errExit != nil { if errExit != nil {
c.mu.Lock() c.mu.Lock()
if c.serviceIsRunning { if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit) c.log.Debugf("failed to read message from relay server: %s", errExit)
} }
c.mu.Unlock() c.mu.Unlock()
@@ -283,15 +297,19 @@ func (c *Client) readLoop(relayConn net.Conn) {
switch msgType { switch msgType {
case messages.MsgTypeHealthCheck: case messages.MsgTypeHealthCheck:
log.Debugf("on new heartbeat")
msg := messages.MarshalHealthcheck() msg := messages.MarshalHealthcheck()
_, err := c.relayConn.Write(msg) _, wErr := c.relayConn.Write(msg)
if err != nil { if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to send heartbeat response: %s", err) c.log.Errorf("failed to send heartbeat: %s", wErr)
} }
hc.Heartbeat()
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n]) peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
if err != nil { if err != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to parse transport message: %v", err) c.log.Errorf("failed to parse transport message: %v", err)
}
continue continue
} }
stringID := messages.HashIDToString(peerID) stringID := messages.HashIDToString(peerID)
@@ -313,16 +331,16 @@ func (c *Client) readLoop(relayConn net.Conn) {
bufPtr: bufPtr, bufPtr: bufPtr,
Payload: payload}) Payload: payload})
case messages.MsgTypeClose: case messages.MsgTypeClose:
closedByServer = true
log.Debugf("relay connection close by server") log.Debugf("relay connection close by server")
goto Exit goto Exit
} }
} }
Exit: Exit:
hc.Stop()
c.notifyDisconnected() c.notifyDisconnected()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(closedByServer) _ = c.close(false)
} }
// todo check by reference too, the id is not enough because the id comes from the outer conn // todo check by reference too, the id is not enough because the id comes from the outer conn
@@ -352,6 +370,27 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
return n, err return n, err
} }
// listenForStopEvents blocks until the first event that must end the given
// relay connection, reacts to it once, and then returns. It is started as a
// goroutine by readLoop.
//
//   - A value received on hc.OnTimeout is treated as a health check timeout:
//     internalStopFlag is set so that readLoop can recognise the resulting
//     I/O errors as self-inflicted, and conn is closed.
//   - If hc.OnTimeout is closed, the goroutine exits without doing anything.
//     NOTE(review): presumably the channel is closed by hc.Stop(); confirm in
//     the healthcheck package.
//   - When the client's parent context is done, the whole client is torn down
//     via c.close(true).
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
	// Every branch ends the goroutine, so a single select is enough; no
	// surrounding loop is needed.
	select {
	case _, ok := <-hc.OnTimeout:
		if !ok {
			return
		}
		c.log.Errorf("health check timeout")
		internalStopFlag.set()
		_ = conn.Close() // ignore the err because the readLoop will handle it
	case <-c.parentCtx.Done():
		if err := c.close(true); err != nil {
			// Use the client-scoped logger so the message keeps the
			// client_id field, like the timeout message above.
			c.log.Errorf("failed to teardown connection: %s", err)
		}
	}
}
func (c *Client) closeAllConns() { func (c *Client) closeAllConns() {
for _, container := range c.conns { for _, container := range c.conns {
container.close() container.close()
@@ -374,7 +413,7 @@ func (c *Client) closeConn(id string) error {
return nil return nil
} }
func (c *Client) close(byServer bool) error { func (c *Client) close(gracefullyExit bool) error {
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
@@ -387,7 +426,7 @@ func (c *Client) close(byServer bool) error {
c.serviceIsRunning = false c.serviceIsRunning = false
c.closeAllConns() c.closeAllConns()
if !byServer { if gracefullyExit {
c.writeCloseMsg() c.writeCloseMsg()
err = c.relayConn.Close() err = c.relayConn.Close()
} }
@@ -395,7 +434,6 @@ func (c *Client) close(byServer bool) error {
c.wgReadLoop.Wait() c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.serverAddress) c.log.Infof("relay connection closed with: %s", c.serverAddress)
c.ctxCancel()
return err return err
} }

View File

@@ -20,7 +20,7 @@ type Receiver struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
heartbeat chan struct{} heartbeat chan struct{}
live bool alive bool
} }
// NewReceiver creates a new healthcheck receiver and starts the timer in the background // NewReceiver creates a new healthcheck receiver and starts the timer in the background
@@ -60,19 +60,24 @@ func (r *Receiver) waitForHealthcheck() {
for { for {
select { select {
case <-r.heartbeat: case <-r.heartbeat:
r.live = true r.alive = true
case <-ticker.C: case <-ticker.C:
if r.live { if r.alive {
r.live = false r.alive = false
continue continue
} }
select {
case r.OnTimeout <- struct{}{}: r.notifyTimeout()
default:
}
return return
case <-r.ctx.Done(): case <-r.ctx.Done():
return return
} }
} }
} }
// notifyTimeout signals a health check timeout on r.OnTimeout without
// blocking: if the send cannot proceed immediately, the notification is
// dropped.
func (r *Receiver) notifyTimeout() {
select {
case r.OnTimeout <- struct{}{}:
default:
// The send would block; drop the signal rather than stall the caller.
}
}