[server, relay] Fix/relay race disconnection (#4174)

Avoid invalid disconnection notifications caused by closed race dials.
This PR resolves multiple race conditions; the fix is easiest to follow commit by commit. A short sketch of the event-ordering approach follows the list below.

- Remove store dependency from notifier
- Enforce the notification order
- Fix invalid disconnection notification
- Ensure the order of the events on the consumer side
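The consumer-side ordering fix boils down to funnelling online and offline transitions through a single FIFO channel instead of two separate ones. Below is a minimal standalone sketch of that pattern; the identifiers (`peerEvent`, `drain`) are illustrative only and are not the exact types used in this PR.

```go
package main

import "fmt"

// peerEvent is an illustrative stand-in for the listener's internal event type:
// one channel carries both online and offline transitions, so their relative
// order is preserved end to end.
type peerEvent struct {
    peerID string
    online bool
}

// drain reads the first queued event and then empties whatever else is already
// buffered, classifying events in the order they were produced.
func drain(events chan peerEvent) (online, offline []string) {
    e := <-events
    for {
        if e.online {
            online = append(online, e.peerID)
        } else {
            offline = append(offline, e.peerID)
        }
        select {
        case e = <-events:
        default:
            return online, offline
        }
    }
}

func main() {
    events := make(chan peerEvent, 8)
    // With two separate channels the consumer could observe these in either
    // order; with one FIFO channel, "online" is always drained before "offline".
    events <- peerEvent{peerID: "peer-1", online: true}
    events <- peerEvent{peerID: "peer-1", online: false}
    online, offline := drain(events)
    fmt.Println("online:", online, "offline:", offline)
}
```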
Author: Zoltan Papp
Committed: 2025-07-21 19:58:17 +02:00 (via GitHub)
Commit: 86c16cf651 (parent a7af15c4fc)
18 changed files with 235 additions and 118 deletions


@@ -211,7 +211,11 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        arch: [ '386','amd64' ]
+        include:
+          - arch: "386"
+            raceFlag: ""
+          - arch: "amd64"
+            raceFlag: "-race"
     runs-on: ubuntu-22.04
     steps:
       - name: Install Go
@@ -251,9 +255,9 @@ jobs:
       - name: Test
         run: |
           CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
-            go test \
+            go test ${{ matrix.raceFlag }} \
             -exec 'sudo' \
-            -timeout 10m ./signal/...
+            -timeout 10m ./relay/...

   test_signal:
     name: "Signal / Unit"


@@ -24,7 +24,7 @@ type WorkerRelay struct {
     isController bool
     config       ConnConfig
     conn         *Conn
-    relayManager relayClient.ManagerService
+    relayManager *relayClient.Manager
     relayedConn  net.Conn
     relayLock    sync.Mutex
@@ -34,7 +34,7 @@ type WorkerRelay struct {
     wgWatcher *WGWatcher
 }

-func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
+func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
     r := &WorkerRelay{
         peerCtx: ctx,
         log:     log,


@@ -292,7 +292,7 @@ func (c *Client) Close() error {
 }

 func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
-    rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
+    rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
     conn, err := rd.Dial()
     if err != nil {
         return nil, err


@@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
     idAlice := "alice"
     log.Debugf("connect by alice")
     relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
-    err = relayClient.Connect(ctx)
-    if err != nil {
+    if err = relayClient.Connect(ctx); err != nil {
         log.Fatalf("failed to connect to server: %s", err)
     }
+    defer func() {
+        if err := relayClient.Close(); err != nil {
+            log.Errorf("failed to close client: %s", err)
+        }
+    }()

     disconnected := make(chan struct{})
     relayClient.SetOnDisconnectListener(func(_ string) {
@@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
     select {
     case <-disconnected:
     case <-time.After(3 * time.Second):
-        log.Fatalf("timeout waiting for client to disconnect")
+        log.Errorf("timeout waiting for client to disconnect")
     }

     _, err = relayClient.OpenConn(ctx, "bob")


@@ -9,8 +9,8 @@ import (
     log "github.com/sirupsen/logrus"
 )

-var (
-    connectionTimeout = 30 * time.Second
+const (
+    DefaultConnectionTimeout = 30 * time.Second
 )

 type DialeFn interface {
@@ -25,16 +25,18 @@ type dialResult struct {
 }

 type RaceDial struct {
     log       *log.Entry
     serverURL string
     dialerFns []DialeFn
+    connectionTimeout time.Duration
 }

-func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
+func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
     return &RaceDial{
         log:       log,
         serverURL: serverURL,
         dialerFns: dialerFns,
+        connectionTimeout: connectionTimeout,
     }
 }

@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
 }

 func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
-    ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
+    ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
     defer cancel()

     r.log.Infof("dialing Relay server via %s", dfn.Protocol())


@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
     logger := logrus.NewEntry(logrus.New())
     serverURL := "test.server.com"

-    rd := NewRaceDial(logger, serverURL)
+    rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
     conn, err := rd.Dial()
     if err == nil {
         t.Errorf("Expected an error with empty dialers, got nil")
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
         protocolStr: proto,
     }

-    rd := NewRaceDial(logger, serverURL, mockDialer)
+    rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
     conn, err := rd.Dial()
     if err != nil {
         t.Errorf("Expected no error, got %v", err)
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
         protocolStr: "proto2",
     }

-    rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
+    rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
     conn, err := rd.Dial()
     if err != nil {
         t.Errorf("Expected no error, got %v", err)
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
     if conn.RemoteAddr().Network() != proto2 {
         t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
     }
+    _ = conn.Close()
 }

 func TestRaceDialTimeout(t *testing.T) {
     logger := logrus.NewEntry(logrus.New())
     serverURL := "test.server.com"
-    connectionTimeout = 3 * time.Second

     mockDialer := &MockDialer{
         dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
             <-ctx.Done()
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
         protocolStr: "proto1",
     }

-    rd := NewRaceDial(logger, serverURL, mockDialer)
+    rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
     conn, err := rd.Dial()
     if err == nil {
         t.Errorf("Expected an error, got nil")
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
         protocolStr: "protocol2",
     }

-    rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
+    rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
     conn, err := rd.Dial()
     if err == nil {
         t.Errorf("Expected an error, got nil")
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
         protocolStr: proto2,
     }

-    rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
+    rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
     conn, err := rd.Dial()
     if err != nil {
         t.Errorf("Expected no error, got %v", err)


@@ -8,7 +8,8 @@ import (
     log "github.com/sirupsen/logrus"
 )

-var (
+const (
+    // TODO: make it configurable, the manager should validate all configurable parameters
     reconnectingTimeout = 60 * time.Second
 )


@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
 type OnServerCloseListener func()

-// ManagerService is the interface for the relay manager.
-type ManagerService interface {
-    Serve() error
-    OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
-    AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
-    RelayInstanceAddress() (string, error)
-    ServerURLs() []string
-    HasRelayAddress() bool
-    UpdateToken(token *relayAuth.Token) error
-}
-
 // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
 // and automatically reconnect to them in case disconnection.
 // The manager also manage temporary relay connection. If a client wants to communicate with a client on a


@@ -13,7 +13,9 @@ import (
 )

 func TestEmptyURL(t *testing.T) {
-    mgr := NewManager(context.Background(), nil, "alice")
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+    mgr := NewManager(ctx, nil, "alice")
     err := mgr.Serve()
     if err == nil {
         t.Errorf("expected error, got nil")
@@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
     }
 }

-func TestForeginAutoClose(t *testing.T) {
+func TestForeignAutoClose(t *testing.T) {
     ctx := context.Background()
     relayCleanupInterval = 1 * time.Second
+    keepUnusedServerTime = 2 * time.Second
     srvCfg1 := server.ListenerConfig{
         Address: "localhost:1234",
     }
@@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
         t.Fatalf("failed to serve manager: %s", err)
     }

+    // Set up a disconnect listener to track when foreign server disconnects
+    foreignServerURL := toURL(srvCfg2)[0]
+    disconnected := make(chan struct{})
+    onDisconnect := func() {
+        select {
+        case disconnected <- struct{}{}:
+        default:
+        }
+    }
+
     t.Log("open connection to another peer")
-    if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
+    if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
         t.Fatalf("should have failed to open connection to another peer")
     }

-    timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
+    // Add the disconnect listener after the connection attempt
+    if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
+        t.Logf("failed to add close listener (expected if connection failed): %s", err)
+    }
+
+    // Wait for cleanup to happen
+    timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
     t.Logf("waiting for relay cleanup: %s", timeout)
-    time.Sleep(timeout)
-    if len(mgr.relayClients) != 0 {
-        t.Errorf("expected 0, got %d", len(mgr.relayClients))
+
+    select {
+    case <-disconnected:
+        t.Log("foreign relay connection cleaned up successfully")
+    case <-time.After(timeout):
+        t.Log("timeout waiting for cleanup - this might be expected if connection never established")
     }

     t.Logf("closing manager")
@@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
 func TestAutoReconnect(t *testing.T) {
     ctx := context.Background()
-    reconnectingTimeout = 2 * time.Second

     srvCfg := server.ListenerConfig{
         Address: "localhost:1234",
@@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) {
     }

     errChan := make(chan error, 1)
     go func() {
-        err := srv.Listen(srvCfg)
-        if err != nil {
+        if err := srv.Listen(srvCfg); err != nil {
             errChan <- err
         }
     }()


@@ -4,38 +4,76 @@ import (
     "context"
     "fmt"
     "os"
+    "sync"
     "testing"
     "time"

     log "github.com/sirupsen/logrus"
 )

+// Mutex to protect global variable access in tests
+var testMutex sync.Mutex
+
 func TestNewReceiver(t *testing.T) {
+    testMutex.Lock()
+    originalTimeout := heartbeatTimeout
     heartbeatTimeout = 5 * time.Second
+    testMutex.Unlock()
+    defer func() {
+        testMutex.Lock()
+        heartbeatTimeout = originalTimeout
+        testMutex.Unlock()
+    }()
+
     r := NewReceiver(log.WithContext(context.Background()))
+    defer r.Stop()

     select {
     case <-r.OnTimeout:
         t.Error("unexpected timeout")
     case <-time.After(1 * time.Second):
+        // Test passes if no timeout received
     }
 }

 func TestNewReceiverNotReceive(t *testing.T) {
+    testMutex.Lock()
+    originalTimeout := heartbeatTimeout
     heartbeatTimeout = 1 * time.Second
+    testMutex.Unlock()
+    defer func() {
+        testMutex.Lock()
+        heartbeatTimeout = originalTimeout
+        testMutex.Unlock()
+    }()
+
     r := NewReceiver(log.WithContext(context.Background()))
+    defer r.Stop()

     select {
     case <-r.OnTimeout:
+        // Test passes if timeout is received
     case <-time.After(2 * time.Second):
         t.Error("timeout not received")
     }
 }

 func TestNewReceiverAck(t *testing.T) {
+    testMutex.Lock()
+    originalTimeout := heartbeatTimeout
     heartbeatTimeout = 2 * time.Second
+    testMutex.Unlock()
+    defer func() {
+        testMutex.Lock()
+        heartbeatTimeout = originalTimeout
+        testMutex.Unlock()
+    }()
+
     r := NewReceiver(log.WithContext(context.Background()))
+    defer r.Stop()

     r.Heartbeat()
@@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
     for _, tc := range testsCases {
         t.Run(tc.name, func(t *testing.T) {
+            testMutex.Lock()
             originalInterval := healthCheckInterval
             originalTimeout := heartbeatTimeout
             healthCheckInterval = 1 * time.Second
             heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
+            testMutex.Unlock()
             defer func() {
+                testMutex.Lock()
                 healthCheckInterval = originalInterval
                 heartbeatTimeout = originalTimeout
+                testMutex.Unlock()
             }()
             //nolint:tenv
             os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))


@@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
             defer cancel()

             sender := NewSender(log.WithField("test_name", tc.name))
-            go sender.StartHealthCheck(ctx)
+            senderExit := make(chan struct{})
+            go func() {
+                sender.StartHealthCheck(ctx)
+                close(senderExit)
+            }()

             go func() {
                 responded := false
@@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
                 t.Fatalf("should have timed out before %s", testTimeout)
             }

+            select {
+            case <-senderExit:
+            case <-time.After(2 * time.Second):
+                t.Fatalf("sender did not exit in time")
+            }
         })
     }


@@ -20,12 +20,12 @@ type Metrics struct {
     TransferBytesRecv  metric.Int64Counter
     AuthenticationTime metric.Float64Histogram
     PeerStoreTime      metric.Float64Histogram
+    peerReconnections  metric.Int64Counter

     peers            metric.Int64UpDownCounter
     peerActivityChan chan string
     peerLastActive   map[string]time.Time
     mutexActivity    sync.Mutex
     ctx              context.Context
 }

 func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
@@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
         return nil, err
     }

+    peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
+        metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
+    )
+    if err != nil {
+        return nil, err
+    }
+
     m := &Metrics{
         Meter:             meter,
         TransferBytesSent: bytesSent,
@@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
         AuthenticationTime: authTime,
         PeerStoreTime:      peerStoreTime,
         peers:              peers,
+        peerReconnections:  peerReconnections,
         ctx:                ctx,

         peerActivityChan: make(chan string, 10),
@@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
     delete(m.peerLastActive, id)
 }

+func (m *Metrics) RecordPeerReconnection() {
+    m.peerReconnections.Add(m.ctx, 1)
+}
+
 // PeerActivity increases the active connections
 func (m *Metrics) PeerActivity(peerID string) {
     select {


@@ -18,12 +18,9 @@ type Listener struct {
     TLSConfig *tls.Config

     listener *quic.Listener
-    acceptFn func(conn net.Conn)
 }

 func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
-    l.acceptFn = acceptFn
-
     quicCfg := &quic.Config{
         EnableDatagrams:   true,
         InitialPacketSize: 1452,
@@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
         log.Infof("QUIC client connected from: %s", session.RemoteAddr())
         conn := NewConn(session)

-        l.acceptFn(conn)
+        acceptFn(conn)
     }
 }


@@ -32,6 +32,9 @@ type Peer struct {
     notifier      *store.PeerNotifier
     peersListener *store.Listener
+
+    // between the online peer collection step and the notification sending should not be sent offline notifications from another thread
+    notificationMutex sync.Mutex
 }

 // NewPeer creates a new Peer instance and prepare custom logging
@@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
     }

     p.log.Debugf("received subscription message for %d peers", len(peerIDs))
-    onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
+
+    // collect online peers to response back to the caller
+    p.notificationMutex.Lock()
+    defer p.notificationMutex.Unlock()
+
+    onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
     if len(onlinePeers) == 0 {
         return
     }

     p.log.Debugf("response with %d online peers", len(onlinePeers))
     p.sendPeersOnline(onlinePeers)
 }
@@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
 }

 func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
+    p.notificationMutex.Lock()
+    defer p.notificationMutex.Unlock()
+
     msgs, err := messages.MarshalPeersWentOffline(peers)
     if err != nil {
         p.log.Errorf("failed to marshal peer location message: %s", err)


@@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) {
         return nil, fmt.Errorf("creating app metrics: %v", err)
     }

-    peerStore := store.NewStore()
     r := &Relay{
         metrics:       m,
         metricsCancel: metricsCancel,
         validator:     config.AuthValidator,
         instanceURL:   config.instanceURL,
-        store:         peerStore,
-        notifier:      store.NewPeerNotifier(peerStore),
+        store:         store.NewStore(),
+        notifier:      store.NewPeerNotifier(),
     }

     r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) {
     peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
     peer.log.Infof("peer connected from: %s", conn.RemoteAddr())

     storeTime := time.Now()
-    r.store.AddPeer(peer)
+    if isReconnection := r.store.AddPeer(peer); isReconnection {
+        r.metrics.RecordPeerReconnection()
+    }
     r.notifier.PeerCameOnline(peer.ID())
     r.metrics.RecordPeerStoreTime(time.Since(storeTime))
     r.metrics.PeerConnected(peer.String())
     go func() {
         peer.Work()
-        r.notifier.PeerWentOffline(peer.ID())
-        r.store.DeletePeer(peer)
+        if deleted := r.store.DeletePeer(peer); deleted {
+            r.notifier.PeerWentOffline(peer.ID())
+        }
         peer.log.Debugf("relay connection closed")
         r.metrics.PeerDisconnected(peer.String())
     }()


@@ -7,24 +7,27 @@ import (
     "github.com/netbirdio/netbird/relay/messages"
 )

-type Listener struct {
-    ctx   context.Context
-    store *Store
+type event struct {
+    peerID messages.PeerID
+    online bool
+}

-    onlineChan  chan messages.PeerID
-    offlineChan chan messages.PeerID
+type Listener struct {
+    ctx context.Context
+
+    eventChan chan *event

     interestedPeersForOffline map[messages.PeerID]struct{}
     interestedPeersForOnline  map[messages.PeerID]struct{}
     mu                        sync.RWMutex
 }

-func newListener(ctx context.Context, store *Store) *Listener {
+func newListener(ctx context.Context) *Listener {
     l := &Listener{
         ctx: ctx,
-        store: store,

-        onlineChan:  make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
-        offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
+        // important to use a single channel for offline and online events because with it we can ensure all events
+        // will be processed in the order they were sent
+        eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
+
         interestedPeersForOffline: make(map[messages.PeerID]struct{}),
         interestedPeersForOnline:  make(map[messages.PeerID]struct{}),
     }
@@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener {
     return l
 }

-func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
-    availablePeers := make([]messages.PeerID, 0)
+func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
     l.mu.Lock()
     defer l.mu.Unlock()
@@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer
         l.interestedPeersForOnline[id] = struct{}{}
         l.interestedPeersForOffline[id] = struct{}{}
     }
-
-    // collect online peers to response back to the caller
-    for _, id := range peerIDs {
-        _, ok := l.store.Peer(id)
-        if !ok {
-            continue
-        }
-        availablePeers = append(availablePeers, id)
-    }
-
-    return availablePeers
 }

 func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
@@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
     for _, id := range peerIDs {
         delete(l.interestedPeersForOffline, id)
         delete(l.interestedPeersForOnline, id)
-
     }
 }

@@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
         select {
         case <-l.ctx.Done():
             return
-        case pID := <-l.onlineChan:
-            peers := make([]messages.PeerID, 0)
-            peers = append(peers, pID)
-
-            for len(l.onlineChan) > 0 {
-                pID = <-l.onlineChan
-                peers = append(peers, pID)
-            }
-            onPeersComeOnline(peers)
-        case pID := <-l.offlineChan:
-            peers := make([]messages.PeerID, 0)
-            peers = append(peers, pID)
-
-            for len(l.offlineChan) > 0 {
-                pID = <-l.offlineChan
-                peers = append(peers, pID)
-            }
-            onPeersWentOffline(peers)
+        case e := <-l.eventChan:
+            peersOffline := make([]messages.PeerID, 0)
+            peersOnline := make([]messages.PeerID, 0)
+            if e.online {
+                peersOnline = append(peersOnline, e.peerID)
+            } else {
+                peersOffline = append(peersOffline, e.peerID)
+            }
+
+            // Drain the channel to collect all events
+            for len(l.eventChan) > 0 {
+                e = <-l.eventChan
+                if e.online {
+                    peersOnline = append(peersOnline, e.peerID)
+                } else {
+                    peersOffline = append(peersOffline, e.peerID)
+                }
+            }
+
+            if len(peersOnline) > 0 {
+                onPeersComeOnline(peersOnline)
+            }
+            if len(peersOffline) > 0 {
+                onPeersWentOffline(peersOffline)
+            }
         }
     }
 }
@@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
     if _, ok := l.interestedPeersForOffline[peerID]; ok {
         select {
-        case l.offlineChan <- peerID:
+        case l.eventChan <- &event{
+            peerID: peerID,
+            online: false,
+        }:
         case <-l.ctx.Done():
         }
     }
@@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
     if _, ok := l.interestedPeersForOnline[peerID]; ok {
         select {
-        case l.onlineChan <- peerID:
+        case l.eventChan <- &event{
+            peerID: peerID,
+            online: true,
+        }:
         case <-l.ctx.Done():
         }
         delete(l.interestedPeersForOnline, peerID)
     }
 }


@@ -8,15 +8,12 @@ import (
 )

 type PeerNotifier struct {
-    store *Store
-
     listeners      map[*Listener]context.CancelFunc
     listenersMutex sync.RWMutex
 }

-func NewPeerNotifier(store *Store) *PeerNotifier {
+func NewPeerNotifier() *PeerNotifier {
     pn := &PeerNotifier{
-        store: store,
         listeners: make(map[*Listener]context.CancelFunc),
     }
     return pn
@@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
 func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
     ctx, cancel := context.WithCancel(context.Background())

-    listener := newListener(ctx, pn.store)
+    listener := newListener(ctx)
     go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)

     pn.listenersMutex.Lock()


@@ -26,7 +26,9 @@ func NewStore() *Store {
 }

 // AddPeer adds a peer to the store
-func (s *Store) AddPeer(peer IPeer) {
+// If the peer already exists, it will be replaced and the old peer will be closed
+// Returns true if the peer was replaced, false if it was added for the first time.
+func (s *Store) AddPeer(peer IPeer) bool {
     s.peersLock.Lock()
     defer s.peersLock.Unlock()
     odlPeer, ok := s.peers[peer.ID()]
@@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) {
     }

     s.peers[peer.ID()] = peer
+    return ok
 }

 // DeletePeer deletes a peer from the store
-func (s *Store) DeletePeer(peer IPeer) {
+func (s *Store) DeletePeer(peer IPeer) bool {
     s.peersLock.Lock()
     defer s.peersLock.Unlock()

     dp, ok := s.peers[peer.ID()]
     if !ok {
-        return
+        return false
     }
     if dp != peer {
-        return
+        return false
     }

     delete(s.peers, peer.ID())
+    return true
 }

 // Peer returns a peer by its ID
@@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer {
     }
     return peers
 }
+
+func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
+    s.peersLock.RLock()
+    defer s.peersLock.RUnlock()
+
+    onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
+
+    listener.AddInterestedPeers(peerIDs)
+
+    // Check for currently online peers
+    for _, id := range peerIDs {
+        if _, ok := s.peers[id]; ok {
+            onlinePeers = append(onlinePeers, id)
+        }
+    }
+
+    return onlinePeers
+}