netbird/relay/server/peer.go

175 lines
3.7 KiB
Go
Raw Normal View History

2024-05-17 17:43:28 +02:00
package server
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/relay/healthcheck"
	"github.com/netbirdio/netbird/relay/messages"
)
2024-05-17 17:43:28 +02:00
// bufferSize is the size, in bytes, of the buffer Work allocates once and
// reuses for every Read from the peer connection.
const bufferSize = 8820
2024-07-29 21:53:07 +02:00
// Peer represents a peer connection
2024-05-17 17:43:28 +02:00
type Peer struct {
log *log.Entry
idS string
idB []byte
conn net.Conn
connMu sync.RWMutex
store *Store
2024-05-17 17:43:28 +02:00
}
2024-07-29 21:53:07 +02:00
// NewPeer creates a new Peer for the raw peer id and its connection, and
// prepares a logger tagged with the peer's string ID.
//
// id is copied, so later reuse of the caller's buffer cannot change the raw
// ID stored in the Peer or make it disagree with the derived string ID.
func NewPeer(id []byte, conn net.Conn, store *Store) *Peer {
	stringID := messages.HashIDToString(id)
	return &Peer{
		log:   log.WithField("peer_id", stringID),
		idS:   stringID,
		idB:   append([]byte(nil), id...),
		conn:  conn,
		store: store,
	}
}
2024-07-29 21:53:07 +02:00
// Work reads data from the connection
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly.
func (p *Peer) Work() {
2024-06-27 18:40:12 +02:00
ctx, cancel := context.WithCancel(context.Background())
hc := healthcheck.NewSender(ctx)
go p.healthcheck(ctx, hc)
defer cancel()
buf := make([]byte, bufferSize)
for {
n, err := p.conn.Read(buf)
if err != nil {
if err != io.EOF {
p.log.Errorf("failed to read message: %s", err)
}
return
}
msg := buf[:n]
msgType, err := messages.DetermineClientMsgType(msg)
if err != nil {
p.log.Errorf("failed to determine message type: %s", err)
return
}
switch msgType {
2024-06-27 18:40:12 +02:00
case messages.MsgTypeHealthCheck:
hc.OnHCResponse()
case messages.MsgTypeTransport:
2024-07-09 16:50:29 +02:00
p.handleTransportMsg(msg)
2024-06-27 18:40:12 +02:00
case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully")
_ = p.conn.Close()
return
}
2024-05-17 17:43:28 +02:00
}
}
// Write sends b over the peer connection and reports how many bytes were
// written together with any write error. It holds connMu for reading, so
// writes may proceed concurrently with one another but never overlap
// CloseGracefully, which takes the lock exclusively.
func (p *Peer) Write(b []byte) (n int, err error) {
	p.connMu.RLock()
	defer p.connMu.RUnlock()

	n, err = p.conn.Write(b)
	return n, err
}
2024-07-29 21:53:07 +02:00
// CloseGracefully closes the connection with the peer gracefully. It sends a
// close message to the client and then closes the connection. The send is
// bounded by writeWithTimeout and by ctx.
//
// The exclusive connMu lock is held throughout so that no concurrent Write
// can interleave with the close message or run against a half-closed peer.
func (p *Peer) CloseGracefully(ctx context.Context) {
	p.connMu.Lock()
	defer p.connMu.Unlock()

	if _, err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()); err != nil {
		p.log.Errorf("failed to send close message to peer: %s", err)
	}

	// Close unconditionally: the connection must not stay open even if the
	// close message could not be delivered. Closing also unblocks a write that
	// writeWithTimeout may have abandoned after giving up.
	if err := p.conn.Close(); err != nil {
		p.log.Errorf("failed to close connection to peer: %s", err)
	}
}
2024-07-29 21:53:07 +02:00
// String returns the peer ID
2024-05-23 13:24:02 +02:00
func (p *Peer) String() string {
return p.idS
2024-05-17 17:43:28 +02:00
}
// writeWithTimeout writes buf to the peer connection. It gives up after three
// seconds or when ctx is done, whichever comes first.
//
// On completion it returns the number of bytes actually written and the
// write error. When it gives up, it returns 0 and an error wrapping ctx.Err().
// It does not take connMu itself.
//
// The write runs in its own goroutine because a net.Conn write cannot observe
// ctx. If this method gives up, that goroutine keeps running until the write
// finishes or the connection is closed.
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) (int, error) {
	const writeTimeout = 3 * time.Second

	ctx, cancel := context.WithTimeout(ctx, writeTimeout)
	defer cancel()

	type writeResult struct {
		n   int
		err error
	}
	// Buffered so the writer goroutine can always hand off its result and
	// exit, even after this method has returned on timeout.
	done := make(chan writeResult, 1)
	go func() {
		n, err := p.conn.Write(buf)
		done <- writeResult{n: n, err: err}
	}()

	select {
	case <-ctx.Done():
		return 0, fmt.Errorf("write operation timed out: %w", ctx.Err())
	case res := <-done:
		return res.n, res.err
	}
}
2024-06-27 18:40:12 +02:00
// healthcheck reacts to the healthcheck sender until ctx is done:
//   - when hc signals on HealthCheck, it sends a healthcheck message to the peer;
//   - when hc signals on Timeout, it closes the connection. Closing makes the
//     pending Read in Work fail, so Work returns.
//
// If sending a healthcheck message fails, the connection is closed as well.
// Otherwise this goroutine would exit, nobody would act on a later timeout,
// and Work could stay blocked reading from a dead connection.
func (p *Peer) healthcheck(ctx context.Context, hc *healthcheck.Sender) {
	for {
		select {
		case <-hc.HealthCheck:
			if _, err := p.Write(messages.MarshalHealthcheck()); err != nil {
				p.log.Errorf("failed to send healthcheck message: %s", err)
				_ = p.conn.Close()
				return
			}
		case <-hc.Timeout:
			p.log.Errorf("peer healthcheck timeout")
			_ = p.conn.Close()
			return
		case <-ctx.Done():
			return
		}
	}
}
2024-07-09 16:50:29 +02:00
// handleTransportMsg forwards a transport message received from this peer.
//
// It parses the destination ID from msg and looks that peer up in the store.
// It then passes msg with this peer's raw ID to messages.UpdateTransportMsg
// and writes msg to the destination. A failure at any step is logged and the
// message is dropped.
func (p *Peer) handleTransportMsg(msg []byte) {
	peerID, err := messages.UnmarshalTransportID(msg)
	if err != nil {
		p.log.Errorf("failed to unmarshal transport message: %s", err)
		return
	}

	stringPeerID := messages.HashIDToString(peerID)
	dp, ok := p.store.Peer(stringPeerID)
	if !ok {
		p.log.Errorf("peer not found: %s", stringPeerID)
		return
	}

	// NOTE(review): presumably replaces the destination ID in msg with the
	// sender's ID so the receiver knows the origin. Confirm against
	// messages.UpdateTransportMsg.
	if err := messages.UpdateTransportMsg(msg, p.idB); err != nil {
		p.log.Errorf("failed to update transport message: %s", err)
		return
	}

	if _, err := dp.Write(msg); err != nil {
		p.log.Errorf("failed to write transport message to %s: %s", dp.String(), err)
	}
}