diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index cbce3e6e4..0c3862e33 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -211,7 +211,11 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        arch: [ '386','amd64' ]
+        include:
+          - arch: "386"
+            raceFlag: ""
+          - arch: "amd64"
+            raceFlag: "-race"
     runs-on: ubuntu-22.04
     steps:
       - name: Install Go
@@ -251,9 +255,9 @@
       - name: Test
         run: |
           CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
-            go test \
+            go test ${{ matrix.raceFlag }} \
             -exec 'sudo' \
-            -timeout 10m ./signal/...
+            -timeout 10m ./relay/...
 
   test_signal:
     name: "Signal / Unit"
diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go
index 5e2900609..ef9f24a2b 100644
--- a/client/internal/peer/worker_relay.go
+++ b/client/internal/peer/worker_relay.go
@@ -24,7 +24,7 @@ type WorkerRelay struct {
 	isController bool
 	config       ConnConfig
 	conn         *Conn
-	relayManager relayClient.ManagerService
+	relayManager *relayClient.Manager
 
 	relayedConn net.Conn
 	relayLock   sync.Mutex
@@ -34,7 +34,7 @@ type WorkerRelay struct {
 	wgWatcher *WGWatcher
 }
 
-func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
+func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
 	r := &WorkerRelay{
 		peerCtx: ctx,
 		log:     log,
diff --git a/relay/client/client.go b/relay/client/client.go
index 32dfbb4db..e4db278f5 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -292,7 +292,7 @@ func (c *Client) Close() error {
 }
 
 func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
-	rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
+	rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
 	conn, err := rd.Dial()
 	if err != nil {
 		return nil, err
diff --git a/relay/client/client_test.go b/relay/client/client_test.go
index dd5f5fe1e..c85ec9fd3 100644
--- a/relay/client/client_test.go
+++ b/relay/client/client_test.go
@@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
 	idAlice := "alice"
 	log.Debugf("connect by alice")
 	relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
-	err = relayClient.Connect(ctx)
-	if err != nil {
+	if err = relayClient.Connect(ctx); err != nil {
 		log.Fatalf("failed to connect to server: %s", err)
 	}
+	defer func() {
+		if err := relayClient.Close(); err != nil {
+			log.Errorf("failed to close client: %s", err)
+		}
+	}()
 
 	disconnected := make(chan struct{})
 	relayClient.SetOnDisconnectListener(func(_ string) {
@@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
 	select {
 	case <-disconnected:
 	case <-time.After(3 * time.Second):
-		log.Fatalf("timeout waiting for client to disconnect")
+		log.Errorf("timeout waiting for client to disconnect")
 	}
 
 	_, err = relayClient.OpenConn(ctx, "bob")
diff --git a/relay/client/dialer/race_dialer.go b/relay/client/dialer/race_dialer.go
index 11dba5799..0550fc63e 100644
--- a/relay/client/dialer/race_dialer.go
+++ b/relay/client/dialer/race_dialer.go
@@ -9,8 +9,8 @@ import (
 	log "github.com/sirupsen/logrus"
 )
 
-var (
-	connectionTimeout = 30 * time.Second
+const (
+	DefaultConnectionTimeout = 30 * time.Second
 )
 
 type DialeFn interface {
@@ -25,16 +25,18 @@ type dialResult struct {
 }
 
 type RaceDial struct {
-	log       *log.Entry
-	serverURL string
-	dialerFns []DialeFn
+	log               *log.Entry
+	serverURL         string
+	dialerFns         []DialeFn
+	connectionTimeout time.Duration
 }
 
-func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
+func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
 	return &RaceDial{
-		log:       log,
-		serverURL: serverURL,
-		dialerFns: dialerFns,
+		log:               log,
+		serverURL:         serverURL,
+		dialerFns:         dialerFns,
+		connectionTimeout: connectionTimeout,
 	}
 }
 
@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
 }
 
 func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
-	ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
+	ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
 	defer cancel()
 
 	r.log.Infof("dialing Relay server via %s", dfn.Protocol())
diff --git a/relay/client/dialer/race_dialer_test.go b/relay/client/dialer/race_dialer_test.go
index 989abb0a6..d216ec5e7 100644
--- a/relay/client/dialer/race_dialer_test.go
+++ b/relay/client/dialer/race_dialer_test.go
@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
 	logger := logrus.NewEntry(logrus.New())
 	serverURL := "test.server.com"
 
-	rd := NewRaceDial(logger, serverURL)
+	rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
 	conn, err := rd.Dial()
 	if err == nil {
 		t.Errorf("Expected an error with empty dialers, got nil")
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
 		protocolStr: proto,
 	}
 
-	rd := NewRaceDial(logger, serverURL, mockDialer)
+	rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
 	conn, err := rd.Dial()
 	if err != nil {
 		t.Errorf("Expected no error, got %v", err)
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
 		protocolStr: "proto2",
 	}
 
-	rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
+	rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
 	conn, err := rd.Dial()
 	if err != nil {
 		t.Errorf("Expected no error, got %v", err)
@@ -144,13 +144,13 @@
 	if conn.RemoteAddr().Network() != proto2 {
 		t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
 	}
+	_ = conn.Close()
 }
 
 func TestRaceDialTimeout(t *testing.T) {
 	logger := logrus.NewEntry(logrus.New())
 	serverURL := "test.server.com"
 
-	connectionTimeout = 3 * time.Second
 	mockDialer := &MockDialer{
 		dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
 			<-ctx.Done()
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
 		protocolStr: "proto1",
 	}
 
-	rd := NewRaceDial(logger, serverURL, mockDialer)
+	rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
 	conn, err := rd.Dial()
 	if err == nil {
 		t.Errorf("Expected an error, got nil")
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
 		protocolStr: "protocol2",
 	}
 
-	rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
+	rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
 	conn, err := rd.Dial()
 	if err == nil {
 		t.Errorf("Expected an error, got nil")
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
 		protocolStr: proto2,
 	}
 
-	rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
+	rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
 	conn, err := rd.Dial()
 	if err != nil {
 		t.Errorf("Expected no error, got %v", err)
diff --git a/relay/client/guard.go b/relay/client/guard.go
index 100892d81..f4d3a8cce 100644
--- a/relay/client/guard.go
+++ b/relay/client/guard.go
@@ -8,7 +8,8 @@ import (
 	log "github.com/sirupsen/logrus"
 )
 
-var (
+const (
+	// TODO: make this configurable; the manager should validate all configurable parameters
 	reconnectingTimeout = 60 * time.Second
 )
 
diff --git a/relay/client/manager.go b/relay/client/manager.go
index b97bc0b99..f32bb9f26 100644
--- a/relay/client/manager.go
+++ b/relay/client/manager.go
@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
 
 type OnServerCloseListener func()
 
-// ManagerService is the interface for the relay manager.
-type ManagerService interface {
-	Serve() error
-	OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
-	AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
-	RelayInstanceAddress() (string, error)
-	ServerURLs() []string
-	HasRelayAddress() bool
-	UpdateToken(token *relayAuth.Token) error
-}
-
 // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
 // and automatically reconnect to them in case disconnection.
 // The manager also manage temporary relay connection. If a client wants to communicate with a client on a
diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go
index 52f2833e4..d0075f982 100644
--- a/relay/client/manager_test.go
+++ b/relay/client/manager_test.go
@@ -13,7 +13,9 @@ import (
 )
 
 func TestEmptyURL(t *testing.T) {
-	mgr := NewManager(context.Background(), nil, "alice")
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	mgr := NewManager(ctx, nil, "alice")
 	err := mgr.Serve()
 	if err == nil {
 		t.Errorf("expected error, got nil")
@@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
 	}
 }
 
-func TestForeginAutoClose(t *testing.T) {
+func TestForeignAutoClose(t *testing.T) {
 	ctx := context.Background()
 	relayCleanupInterval = 1 * time.Second
+	keepUnusedServerTime = 2 * time.Second
+
 	srvCfg1 := server.ListenerConfig{
 		Address: "localhost:1234",
 	}
@@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
 		t.Fatalf("failed to serve manager: %s", err)
 	}
 
+	// Set up a disconnect listener to track when the foreign server disconnects
+	foreignServerURL := toURL(srvCfg2)[0]
+	disconnected := make(chan struct{})
+	onDisconnect := func() {
+		select {
+		case disconnected <- struct{}{}:
+		default:
+		}
+	}
+
 	t.Log("open connection to another peer")
-	if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
+	if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
 		t.Fatalf("should have failed to open connection to another peer")
 	}
 
-	timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
+	// Add the disconnect listener after the connection attempt
+	if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
+		t.Logf("failed to add close listener (expected if connection failed): %s", err)
+	}
+
+	// Wait for cleanup to happen
+	timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
 	t.Logf("waiting for relay cleanup: %s", timeout)
-	time.Sleep(timeout)
-	if len(mgr.relayClients) != 0 {
-		t.Errorf("expected 0, got %d", len(mgr.relayClients))
+
+	select {
+	case <-disconnected:
+		t.Log("foreign relay connection cleaned up successfully")
+	case <-time.After(timeout):
+		t.Log("timeout waiting for cleanup; this might be expected if the connection was never established")
 	}
 
 	t.Logf("closing manager")
@@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
 
 func TestAutoReconnect(t *testing.T) {
 	ctx := context.Background()
-	reconnectingTimeout = 2 * time.Second
 
 	srvCfg := server.ListenerConfig{
 		Address: "localhost:1234",
@@ -312,8 +334,7 @@
 	}
 	errChan := make(chan error, 1)
 	go func() {
-		err := srv.Listen(srvCfg)
-		if err != nil {
+		if err := srv.Listen(srvCfg); err != nil {
 			errChan <- err
 		}
 	}()
diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go
index 3b3e32fe6..2794159f6 100644
--- a/relay/healthcheck/receiver_test.go
+++ b/relay/healthcheck/receiver_test.go
@@ -4,38 +4,76 @@ import (
 	"context"
 	"fmt"
 	"os"
+	"sync"
 	"testing"
 	"time"
 
 	log "github.com/sirupsen/logrus"
 )
 
+// Mutex to protect global variable access in tests
+var testMutex sync.Mutex
+
 func TestNewReceiver(t *testing.T) {
+	testMutex.Lock()
+	originalTimeout := heartbeatTimeout
 	heartbeatTimeout = 5 * time.Second
+	testMutex.Unlock()
+
+	defer func() {
+		testMutex.Lock()
+		heartbeatTimeout = originalTimeout
+		testMutex.Unlock()
+	}()
+
 	r := NewReceiver(log.WithContext(context.Background()))
+	defer r.Stop()
 
 	select {
 	case <-r.OnTimeout:
 		t.Error("unexpected timeout")
 	case <-time.After(1 * time.Second):
-
+		// Test passes if no timeout received
 	}
 }
 
 func TestNewReceiverNotReceive(t *testing.T) {
+	testMutex.Lock()
+	originalTimeout := heartbeatTimeout
 	heartbeatTimeout = 1 * time.Second
+	testMutex.Unlock()
+
+	defer func() {
+		testMutex.Lock()
+		heartbeatTimeout = originalTimeout
+		testMutex.Unlock()
+	}()
+
 	r := NewReceiver(log.WithContext(context.Background()))
+	defer r.Stop()
 
 	select {
 	case <-r.OnTimeout:
+		// Test passes if timeout is received
 	case <-time.After(2 * time.Second):
 		t.Error("timeout not received")
 	}
 }
 
 func TestNewReceiverAck(t *testing.T) {
+	testMutex.Lock()
+	originalTimeout := heartbeatTimeout
 	heartbeatTimeout = 2 * time.Second
+	testMutex.Unlock()
+
+	defer func() {
+		testMutex.Lock()
+		heartbeatTimeout = originalTimeout
+		testMutex.Unlock()
+	}()
+
 	r := NewReceiver(log.WithContext(context.Background()))
+	defer r.Stop()
 
 	r.Heartbeat()
@@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
 
 	for _, tc := range testsCases {
 		t.Run(tc.name, func(t *testing.T) {
+			testMutex.Lock()
 			originalInterval := healthCheckInterval
 			originalTimeout := heartbeatTimeout
 			healthCheckInterval = 1 * time.Second
			heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
+			testMutex.Unlock()
+
 			defer func() {
+				testMutex.Lock()
 				healthCheckInterval = originalInterval
 				heartbeatTimeout = originalTimeout
+				testMutex.Unlock()
 			}()
 			//nolint:tenv
 			os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go
index f21167025..39d266b48 100644
--- a/relay/healthcheck/sender_test.go
+++ b/relay/healthcheck/sender_test.go
@@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
 			defer cancel()
 
 			sender := NewSender(log.WithField("test_name", tc.name))
-			go sender.StartHealthCheck(ctx)
+			senderExit := make(chan struct{})
+			go func() {
+				sender.StartHealthCheck(ctx)
+				close(senderExit)
+			}()
 
 			go func() {
 				responded := false
@@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
 				t.Fatalf("should have timed out before %s", testTimeout)
 			}
 
+			select {
+			case <-senderExit:
+			case <-time.After(2 * time.Second):
+				t.Fatalf("sender did not exit in time")
+			}
 		})
 	}
 
diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go
index 2e90940e6..efb597ff5 100644
--- a/relay/metrics/realy.go
+++ b/relay/metrics/realy.go
@@ -20,12 +20,12 @@ type Metrics struct {
 	TransferBytesRecv  metric.Int64Counter
 	AuthenticationTime metric.Float64Histogram
 	PeerStoreTime      metric.Float64Histogram
-
-	peers            metric.Int64UpDownCounter
-	peerActivityChan chan string
-	peerLastActive   map[string]time.Time
-	mutexActivity    sync.Mutex
-	ctx              context.Context
+	peerReconnections metric.Int64Counter
+	peers             metric.Int64UpDownCounter
+	peerActivityChan  chan string
+	peerLastActive    map[string]time.Time
+	mutexActivity     sync.Mutex
+	ctx               context.Context
 }
 
 func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
@@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
 		return nil, err
 	}
 
+	peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
+		metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
+	)
+	if err != nil {
+		return nil, err
+	}
+
 	m := &Metrics{
 		Meter:             meter,
 		TransferBytesSent: bytesSent,
@@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
 		AuthenticationTime: authTime,
 		PeerStoreTime:      peerStoreTime,
 		peers:              peers,
+		peerReconnections:  peerReconnections,
 		ctx:                ctx,
 
 		peerActivityChan: make(chan string, 10),
@@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
 	delete(m.peerLastActive, id)
 }
 
+func (m *Metrics) RecordPeerReconnection() {
+	m.peerReconnections.Add(m.ctx, 1)
+}
+
 // PeerActivity increases the active connections
 func (m *Metrics) PeerActivity(peerID string) {
 	select {
diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go
index 17a5e8ab6..2a4a668f0 100644
--- a/relay/server/listener/quic/listener.go
+++ b/relay/server/listener/quic/listener.go
@@ -18,12 +18,9 @@ type Listener struct {
 	TLSConfig *tls.Config
 
 	listener *quic.Listener
-	acceptFn func(conn net.Conn)
 }
 
 func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
-	l.acceptFn = acceptFn
-
 	quicCfg := &quic.Config{
 		EnableDatagrams:   true,
 		InitialPacketSize: 1452,
@@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
 		log.Infof("QUIC client connected from: %s", session.RemoteAddr())
 		conn := NewConn(session)
 
-		l.acceptFn(conn)
+		acceptFn(conn)
 	}
 }
 
diff --git a/relay/server/peer.go b/relay/server/peer.go
index c6fa8508f..9caa5b06f 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -32,6 +32,9 @@ type Peer struct {
 	notifier      *store.PeerNotifier
 	peersListener *store.Listener
+
+	// prevents another thread from sending offline notifications between the online peer collection step and the notification sending
+	notificationMutex sync.Mutex
 }
 
 // NewPeer creates a new Peer instance and prepare custom logging
@@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
 	}
 
 	p.log.Debugf("received subscription message for %d peers", len(peerIDs))
-	onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
+
+	// collect online peers to respond back to the caller
+	p.notificationMutex.Lock()
+	defer p.notificationMutex.Unlock()
+
+	onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
 	if len(onlinePeers) == 0 {
 		return
 	}
+
 	p.log.Debugf("response with %d online peers", len(onlinePeers))
 	p.sendPeersOnline(onlinePeers)
 }
@@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
 }
 
 func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
+	p.notificationMutex.Lock()
+	defer p.notificationMutex.Unlock()
+
 	msgs, err := messages.MarshalPeersWentOffline(peers)
 	if err != nil {
 		p.log.Errorf("failed to marshal peer location message: %s", err)
diff --git a/relay/server/relay.go b/relay/server/relay.go
index 93fb00edb..d86684937 100644
--- a/relay/server/relay.go
+++ b/relay/server/relay.go
@@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) {
 		return nil, fmt.Errorf("creating app metrics: %v", err)
 	}
 
-	peerStore := store.NewStore()
 	r := &Relay{
 		metrics:       m,
 		metricsCancel: metricsCancel,
 		validator:     config.AuthValidator,
 		instanceURL:   config.instanceURL,
-		store:         peerStore,
-		notifier:      store.NewPeerNotifier(peerStore),
+		store:         store.NewStore(),
+		notifier:      store.NewPeerNotifier(),
 	}
 
 	r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) {
 	peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
 	peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
 	storeTime := time.Now()
-	r.store.AddPeer(peer)
+	if isReconnection := r.store.AddPeer(peer); isReconnection {
+		r.metrics.RecordPeerReconnection()
+	}
 	r.notifier.PeerCameOnline(peer.ID())
 	r.metrics.RecordPeerStoreTime(time.Since(storeTime))
 	r.metrics.PeerConnected(peer.String())
 
 	go func() {
 		peer.Work()
-		r.notifier.PeerWentOffline(peer.ID())
-		r.store.DeletePeer(peer)
+		if deleted := r.store.DeletePeer(peer); deleted {
+			r.notifier.PeerWentOffline(peer.ID())
+		}
 		peer.log.Debugf("relay connection closed")
 		r.metrics.PeerDisconnected(peer.String())
 	}()
diff --git a/relay/server/store/listener.go b/relay/server/store/listener.go
index b7c5f4ce8..e9c77d953 100644
--- a/relay/server/store/listener.go
+++ b/relay/server/store/listener.go
@@ -7,24 +7,27 @@ import (
 	"github.com/netbirdio/netbird/relay/messages"
 )
 
-type Listener struct {
-	ctx   context.Context
-	store *Store
+type event struct {
+	peerID messages.PeerID
+	online bool
+}
 
-	onlineChan  chan messages.PeerID
-	offlineChan chan messages.PeerID
+type Listener struct {
+	ctx context.Context
+
+	eventChan chan *event
 
 	interestedPeersForOffline map[messages.PeerID]struct{}
 	interestedPeersForOnline  map[messages.PeerID]struct{}
 	mu                        sync.RWMutex
 }
 
-func newListener(ctx context.Context, store *Store) *Listener {
+func newListener(ctx context.Context) *Listener {
 	l := &Listener{
-		ctx:   ctx,
-		store: store,
+		ctx: ctx,
 
-		onlineChan:  make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
-		offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
+		// a single channel is used for both offline and online events so that all events
+		// are processed in the order they were sent
+		eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
 		interestedPeersForOffline: make(map[messages.PeerID]struct{}),
 		interestedPeersForOnline:  make(map[messages.PeerID]struct{}),
 	}
@@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener {
 	return l
 }
 
-func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
-	availablePeers := make([]messages.PeerID, 0)
+func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
 	l.mu.Lock()
 	defer l.mu.Unlock()
 
@@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer
 		l.interestedPeersForOnline[id] = struct{}{}
 		l.interestedPeersForOffline[id] = struct{}{}
 	}
-
-	// collect online peers to response back to the caller
-	for _, id := range peerIDs {
-		_, ok := l.store.Peer(id)
-		if !ok {
-			continue
-		}
-
-		availablePeers = append(availablePeers, id)
-	}
-	return availablePeers
 }
 
 func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
@@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
 	for _, id := range peerIDs {
 		delete(l.interestedPeersForOffline, id)
 		delete(l.interestedPeersForOnline, id)
-
 	}
 }
 
@@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
 		select {
 		case <-l.ctx.Done():
 			return
-		case pID := <-l.onlineChan:
-			peers := make([]messages.PeerID, 0)
-			peers = append(peers, pID)
-
-			for len(l.onlineChan) > 0 {
-				pID = <-l.onlineChan
-				peers = append(peers, pID)
+		case e := <-l.eventChan:
+			peersOffline := make([]messages.PeerID, 0)
+			peersOnline := make([]messages.PeerID, 0)
+			if e.online {
+				peersOnline = append(peersOnline, e.peerID)
+			} else {
+				peersOffline = append(peersOffline, e.peerID)
 			}
-			onPeersComeOnline(peers)
-		case pID := <-l.offlineChan:
-			peers := make([]messages.PeerID, 0)
-			peers = append(peers, pID)
-
-			for len(l.offlineChan) > 0 {
-				pID = <-l.offlineChan
-				peers = append(peers, pID)
+			// Drain the channel to collect all events
+			for len(l.eventChan) > 0 {
+				e = <-l.eventChan
+				if e.online {
+					peersOnline = append(peersOnline, e.peerID)
+				} else {
+					peersOffline = append(peersOffline, e.peerID)
+				}
 			}
-			onPeersWentOffline(peers)
+			if len(peersOnline) > 0 {
+				onPeersComeOnline(peersOnline)
+			}
+			if len(peersOffline) > 0 {
+				onPeersWentOffline(peersOffline)
+			}
 		}
 	}
 }
@@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
 
 	if _, ok := l.interestedPeersForOffline[peerID]; ok {
 		select {
-		case l.offlineChan <- peerID:
+		case l.eventChan <- &event{
+			peerID: peerID,
+			online: false,
+		}:
 		case <-l.ctx.Done():
 		}
 	}
@@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
 
 	if _, ok := l.interestedPeersForOnline[peerID]; ok {
 		select {
-		case l.onlineChan <- peerID:
+		case l.eventChan <- &event{
+			peerID: peerID,
+			online: true,
+		}:
 		case <-l.ctx.Done():
 		}
+		delete(l.interestedPeersForOnline, peerID)
 	}
 }
diff --git a/relay/server/store/notifier.go b/relay/server/store/notifier.go
index ad2e53545..335522537 100644
--- a/relay/server/store/notifier.go
+++ b/relay/server/store/notifier.go
@@ -8,15 +8,12 @@ import (
 )
 
 type PeerNotifier struct {
-	store *Store
-
 	listeners      map[*Listener]context.CancelFunc
 	listenersMutex sync.RWMutex
 }
 
-func NewPeerNotifier(store *Store) *PeerNotifier {
+func NewPeerNotifier() *PeerNotifier {
 	pn := &PeerNotifier{
-		store:     store,
 		listeners: make(map[*Listener]context.CancelFunc),
 	}
 	return pn
@@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
 
 func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
 	ctx, cancel := context.WithCancel(context.Background())
-	listener := newListener(ctx, pn.store)
+	listener := newListener(ctx)
 	go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
 
 	pn.listenersMutex.Lock()
diff --git a/relay/server/store/store.go b/relay/server/store/store.go
index c19fb416f..fd0578603 100644
--- a/relay/server/store/store.go
+++ b/relay/server/store/store.go
@@ -26,7 +26,9 @@ func NewStore() *Store {
 }
 
 // AddPeer adds a peer to the store
-func (s *Store) AddPeer(peer IPeer) {
+// If the peer already exists, it will be replaced and the old peer will be closed.
+// Returns true if the peer was replaced, false if it was added for the first time.
+func (s *Store) AddPeer(peer IPeer) bool {
 	s.peersLock.Lock()
 	defer s.peersLock.Unlock()
 	odlPeer, ok := s.peers[peer.ID()]
@@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) {
 	}
 
 	s.peers[peer.ID()] = peer
+	return ok
 }
 
 // DeletePeer deletes a peer from the store
-func (s *Store) DeletePeer(peer IPeer) {
+func (s *Store) DeletePeer(peer IPeer) bool {
 	s.peersLock.Lock()
 	defer s.peersLock.Unlock()
 
 	dp, ok := s.peers[peer.ID()]
 	if !ok {
-		return
+		return false
 	}
 	if dp != peer {
-		return
+		return false
 	}
 
 	delete(s.peers, peer.ID())
+	return true
 }
 
 // Peer returns a peer by its ID
@@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer {
 	}
 	return peers
 }
+
+func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
+	s.peersLock.RLock()
+	defer s.peersLock.RUnlock()
+
+	onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
+
+	listener.AddInterestedPeers(peerIDs)
+
+	// Check for currently online peers
+	for _, id := range peerIDs {
+		if _, ok := s.peers[id]; ok {
+			onlinePeers = append(onlinePeers, id)
+		}
+	}
+
+	return onlinePeers
+}