diff --git a/relay/client/client.go b/relay/client/client.go
index 3f689e6dd..2913550c7 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -141,6 +141,7 @@ func (c *Client) Connect() error {
 			log.Errorf("failed to close relay connection: %s", cErr)
 		}
 	})
+	c.wgReadLoop.Add(1)
 
 	go c.readLoop(c.relayConn)
 
@@ -306,6 +307,12 @@ func (c *Client) readLoop(relayConn net.Conn) {
 		}
 
 		switch msgType {
+		case messages.MsgTypeHealthCheck:
+			msg := messages.MarshalHealthcheck()
+			_, err := c.relayConn.Write(msg)
+			if err != nil {
+				c.log.Errorf("failed to send heartbeat response: %s", err)
+			}
 		case messages.MsgTypeTransport:
 			peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
 			if err != nil {
@@ -330,7 +337,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
 				bufPool: c.bufPool,
 				bufPtr:  bufPtr,
 				Payload: payload})
-		case messages.MsgClose:
+		case messages.MsgTypeClose:
 			closedByServer = true
 			log.Debugf("relay connection close by server")
 			goto Exit
diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go
new file mode 100644
index 000000000..157fb6684
--- /dev/null
+++ b/relay/healthcheck/receiver.go
@@ -0,0 +1,81 @@
+package healthcheck
+
+import (
+	"context"
+	"time"
+)
+
+var (
+	heartbeatTimeout = healthCheckInterval + 3*time.Second
+)
+
+// Receiver is a healthcheck receiver.
+// It listens for heartbeats and checks that one arrives within every heartbeatTimeout window.
+// If no heartbeat arrived during a window, it sends a timeout signal on OnTimeout and stops working.
+// It also stops if the context is canceled.
+// OnTimeout is closed whenever the receiver stops, with or without a preceding timeout signal.
+// The heartbeat timeout is a bit longer than the sender's healthcheck interval.
+type Receiver struct {
+	OnTimeout chan struct{}
+
+	ctx       context.Context
+	ctxCancel context.CancelFunc
+	heartbeat chan struct{}
+	live      bool
+}
+
+// NewReceiver creates a Receiver and starts the goroutine that watches for heartbeats.
+func NewReceiver() *Receiver {
+	ctx, ctxCancel := context.WithCancel(context.Background())
+
+	r := &Receiver{
+		OnTimeout: make(chan struct{}, 1),
+		ctx:       ctx,
+		ctxCancel: ctxCancel,
+		heartbeat: make(chan struct{}, 1),
+	}
+
+	go r.waitForHealthcheck()
+	return r
+}
+
+// Heartbeat reports that a heartbeat was received. It never blocks.
+func (r *Receiver) Heartbeat() {
+	select {
+	case r.heartbeat <- struct{}{}:
+	default:
+	}
+}
+
+// Stop cancels the receiver's context, which ends the watching goroutine.
+func (r *Receiver) Stop() {
+	r.ctxCancel()
+}
+
+// waitForHealthcheck checks on every tick whether a heartbeat arrived since the previous tick.
+// The first tick without one raises the timeout signal and ends the loop.
+func (r *Receiver) waitForHealthcheck() {
+	ticker := time.NewTicker(heartbeatTimeout)
+	defer ticker.Stop()
+	defer r.ctxCancel()
+	defer close(r.OnTimeout)
+
+	for {
+		select {
+		case <-r.heartbeat:
+			r.live = true
+		case <-ticker.C:
+			if r.live {
+				r.live = false
+				continue
+			}
+			select {
+			case r.OnTimeout <- struct{}{}:
+			default:
+			}
+			return
+		case <-r.ctx.Done():
+			return
+		}
+	}
+}
diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go
new file mode 100644
index 000000000..4b4123416
--- /dev/null
+++ b/relay/healthcheck/receiver_test.go
@@ -0,0 +1,42 @@
+package healthcheck
+
+import (
+	"testing"
+	"time"
+)
+
+func TestNewReceiver(t *testing.T) {
+	heartbeatTimeout = 5 * time.Second
+	r := NewReceiver()
+
+	select {
+	case <-r.OnTimeout:
+		t.Error("unexpected timeout")
+	case <-time.After(1 * time.Second):
+
+	}
+}
+
+func TestNewReceiverNotReceive(t *testing.T) {
+	heartbeatTimeout = 1 * time.Second
+	r := NewReceiver()
+
+	select {
+	case <-r.OnTimeout:
+	case <-time.After(2 * time.Second):
+		t.Error("timeout not received")
+	}
+}
+
+func TestNewReceiverAck(t *testing.T) {
+	heartbeatTimeout = 2 * time.Second
+	r := NewReceiver()
+
+	r.Heartbeat()
+
+	select {
+	case <-r.OnTimeout:
+		t.Error("unexpected timeout")
+	case <-time.After(3 * time.Second):
+	}
+}
diff --git a/relay/healthcheck/sender.go b/relay/healthcheck/sender.go
new file mode 100644
index 000000000..401170ec9
--- /dev/null
+++ b/relay/healthcheck/sender.go
@@ -0,0 +1,86 @@
+package healthcheck
+
+import (
+	"context"
+	"time"
+)
+
+var (
+	healthCheckInterval = 25 * time.Second
+	healthCheckTimeout  = 5 * time.Second
+)
+
+// Sender is a healthcheck sender.
+// It signals on HealthCheck every healthCheckInterval that a healthcheck must be sent to the receiver.
+// If no response is reported via OnHCResponse within healthCheckInterval+healthCheckTimeout,
+// it sends a timeout signal on Timeout and stops working.
+// It also stops if the context is canceled. Both channels are closed when the sender stops.
+type Sender struct {
+	HealthCheck chan struct{}
+	Timeout     chan struct{}
+
+	ctx context.Context
+	ack chan struct{}
+}
+
+// NewSender creates a new healthcheck sender and starts its goroutine
+func NewSender(ctx context.Context) *Sender {
+	hc := &Sender{
+		HealthCheck: make(chan struct{}, 1),
+		Timeout:     make(chan struct{}, 1),
+		ctx:         ctx,
+		ack:         make(chan struct{}, 1),
+	}
+
+	go hc.healthCheck()
+	return hc
+}
+
+// OnHCResponse reports that the receiver answered a healthcheck. It never blocks.
+func (hc *Sender) OnHCResponse() {
+	select {
+	case hc.ack <- struct{}{}:
+	default:
+	}
+}
+
+// healthCheck drives the sender until a timeout occurs or the context is canceled.
+func (hc *Sender) healthCheck() {
+	ticker := time.NewTicker(healthCheckInterval)
+	defer ticker.Stop()
+
+	ackDeadline := healthCheckInterval + healthCheckTimeout
+	timeoutTimer := time.NewTimer(ackDeadline)
+	defer timeoutTimer.Stop()
+
+	defer close(hc.HealthCheck)
+	defer close(hc.Timeout)
+
+	for {
+		select {
+		case <-ticker.C:
+			// do not block forever on an abandoned channel: give up once the context is canceled
+			select {
+			case hc.HealthCheck <- struct{}{}:
+			case <-hc.ctx.Done():
+				return
+			}
+		case <-timeoutTimer.C:
+			hc.Timeout <- struct{}{}
+			return
+		case <-hc.ack:
+			// Re-arm the deadline for the next round. Only stopping the timer would switch the
+			// timeout detection off for good after the first response.
+			if !timeoutTimer.Stop() {
+				// the timer already fired: drain the channel so Reset starts from a clean state
+				select {
+				case <-timeoutTimer.C:
+				default:
+				}
+			}
+			timeoutTimer.Reset(ackDeadline)
+		case <-hc.ctx.Done():
+			return
+		}
+	}
+}
diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go
new file mode 100644
index 000000000..a571437ef
--- /dev/null
+++ b/relay/healthcheck/sender_test.go
@@ -0,0 +1,67 @@
+package healthcheck
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+func TestNewHealthPeriod(t *testing.T) {
+	// override the health check interval to speed up the test
+	healthCheckInterval = 1 * time.Second
+	healthCheckTimeout = 100 * time.Millisecond
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	hc := NewSender(ctx)
+
+	iterations := 0
+	for i := 0; i < 3; i++ {
+		select {
+		case <-hc.HealthCheck:
+			iterations++
+			hc.OnHCResponse()
+		case <-hc.Timeout:
+			t.Fatalf("health check is timed out")
+		case <-time.After(healthCheckInterval + 100*time.Millisecond):
+			t.Fatalf("health check not received")
+		}
+	}
+}
+
+func TestNewHealthFailed(t *testing.T) {
+	// override the health check interval to speed up the test
+	healthCheckInterval = 1 * time.Second
+	healthCheckTimeout = 500 * time.Millisecond
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	hc := NewSender(ctx)
+
+	select {
+	case <-hc.Timeout:
+	case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
+		t.Fatalf("health check is not timed out")
+	}
+}
+
+func TestNewHealthcheckStop(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	hc := NewSender(ctx)
+
+	// cancel well before the first tick so that no signal is queued when the sender stops
+	time.Sleep(100 * time.Millisecond)
+	cancel()
+
+	// a stopped sender closes both channels; a receive must report the close, not a signal
+	for _, ch := range []chan struct{}{hc.HealthCheck, hc.Timeout} {
+		select {
+		case _, ok := <-ch:
+			if ok {
+				t.Fatalf("is not closed")
+			}
+		case <-time.After(1 * time.Second):
+			t.Fatalf("is not exited")
+		}
+	}
+}
diff --git a/relay/messages/message.go b/relay/messages/message.go
index d2a6d46d7..aa62fe867 100644
--- a/relay/messages/message.go
+++ b/relay/messages/message.go
@@ -11,7 +11,8 @@ const (
 	MsgTypeHello         MsgType = 0
 	MsgTypeHelloResponse MsgType = 1
 	MsgTypeTransport     MsgType = 2
-	MsgClose             MsgType = 3
+	MsgTypeClose         MsgType = 3
+	MsgTypeHealthCheck   MsgType = 4
 
 	headerSizeTransport = 1 + IDSize     // 1 byte for msg type, IDSize for peerID
 	headerSizeHello     = 1 + 4 + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
@@ -23,6 +24,8 @@ var (
 	ErrInvalidMessageLength = fmt.Errorf("invalid message length")
 
 	magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
+
+	healthCheckMsg = []byte{byte(MsgTypeHealthCheck)}
 )
 
 type MsgType byte
@@ -35,7 +38,7 @@ func (m MsgType) String() string {
 		return "hello response"
 	case MsgTypeTransport:
 		return "transport"
-	case MsgClose:
+	case MsgTypeClose:
 		return "close"
 	default:
 		return "unknown"
@@ -49,7 +52,9 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) {
 		return msgType, nil
 	case MsgTypeTransport:
 		return msgType, nil
-	case MsgClose:
+	case MsgTypeClose:
+		return msgType, nil
+	case MsgTypeHealthCheck:
 		return msgType, nil
 	default:
 		return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
@@ -63,7 +68,9 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) {
 		return msgType, nil
 	case MsgTypeTransport:
 		return msgType, nil
-	case MsgClose:
+	case MsgTypeClose:
+		return msgType, nil
+	case MsgTypeHealthCheck:
 		return msgType, nil
 	default:
 		return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
@@ -100,8 +107,8 @@ func MarshalHelloResponse() []byte {
 
 func MarshalCloseMsg() []byte {
 	msg := make([]byte, 1)
-	msg[0] = byte(MsgClose)
+	msg[0] = byte(MsgTypeClose)
 	return msg
 }
 
 // Transport message
@@ -141,3 +148,11 @@ func UpdateTransportMsg(msg []byte, peerID []byte) error {
 	copy(msg[1:], peerID)
 	return nil
 }
+
+// health check message
+
+// MarshalHealthcheck returns the one-byte health check message.
+// The returned slice is shared by all callers and must not be modified.
+func MarshalHealthcheck() []byte {
+	return healthCheckMsg
+}
diff --git a/relay/server/peer.go b/relay/server/peer.go
index 14a86a8ab..2869340a9 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -10,6 +10,7 @@ import (
 
 	log "github.com/sirupsen/logrus"
 
+	"github.com/netbirdio/netbird/relay/healthcheck"
 	"github.com/netbirdio/netbird/relay/messages"
 )
 
@@ -38,6 +39,11 @@ func NewPeer(id []byte, conn net.Conn, store *Store) *Peer {
 }
 
 func (p *Peer) Work() {
+	ctx, cancel := context.WithCancel(context.Background())
+	hc := healthcheck.NewSender(ctx)
+	go p.healthcheck(ctx, hc)
+	defer cancel()
+
 	buf := make([]byte, bufferSize)
 	for {
 		n, err := p.conn.Read(buf)
@@ -56,7 +62,8 @@ func (p *Peer) Work() {
 			return
 		}
 		switch msgType {
-		case messages.MsgHealthCheck:
+		case messages.MsgTypeHealthCheck:
+			hc.OnHCResponse()
 		case messages.MsgTypeTransport:
 			peerID, err := messages.UnmarshalTransportID(msg)
 			if err != nil {
@@ -78,7 +85,7 @@ func (p *Peer) Work() {
 			if err != nil {
 				p.log.Errorf("failed to write transport message to: %s", dp.String())
 			}
-		case messages.MsgClose:
+		case messages.MsgTypeClose:
 			p.log.Infof("peer exited gracefully")
 			_ = p.conn.Close()
 			return
@@ -135,3 +142,37 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) (int, error) {
 		return n, err
 	}
 }
+
+// healthcheck forwards the sender's signals to the peer connection: it writes a healthcheck message
+// for every HealthCheck signal and closes the connection when the sender reports a timeout.
+// It returns when ctx is canceled, when the sender stops, or when a write fails.
+func (p *Peer) healthcheck(ctx context.Context, hc *healthcheck.Sender) {
+	healthCheckCh := hc.HealthCheck
+	for {
+		select {
+		case _, ok := <-healthCheckCh:
+			if !ok {
+				// The sender closed the channel because it stopped. A nil channel is never ready,
+				// so the pending Timeout signal (if any) or ctx decides how this loop ends.
+				healthCheckCh = nil
+				continue
+			}
+			p.log.Debugf("sending healthcheck message")
+			_, err := p.Write(messages.MarshalHealthcheck())
+			if err != nil {
+				p.log.Errorf("failed to send healthcheck message: %s", err)
+				return
+			}
+		case _, ok := <-hc.Timeout:
+			if !ok {
+				// closed without a signal: the sender was stopped, this is not a timeout
+				return
+			}
+			p.log.Errorf("peer healthcheck timeout")
+			_ = p.conn.Close()
+			return
+		case <-ctx.Done():
+			return
+		}
+	}
+}