[server, relay] Fix/relay race disconnection (#4174)

Avoid invalid disconnection notifications in case the closed race dials.
In this PR resolve multiple race condition questions. Easier to understand the fix based on commit by commit.

- Remove store dependency from notifier
- Enforce the notification orders
- Fix invalid disconnection notification
- Ensure the order of the events on the consumer side
This commit is contained in:
Zoltan Papp
2025-07-21 19:58:17 +02:00
committed by GitHub
parent a7af15c4fc
commit 86c16cf651
18 changed files with 235 additions and 118 deletions

View File

@ -211,7 +211,11 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Install Go
@ -251,9 +255,9 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m ./signal/...
-timeout 10m ./relay/...
test_signal:
name: "Signal / Unit"

View File

@ -24,7 +24,7 @@ type WorkerRelay struct {
isController bool
config ConnConfig
conn *Conn
relayManager relayClient.ManagerService
relayManager *relayClient.Manager
relayedConn net.Conn
relayLock sync.Mutex
@ -34,7 +34,7 @@ type WorkerRelay struct {
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,

View File

@ -292,7 +292,7 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return nil, err

View File

@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect(ctx)
if err != nil {
if err = relayClient.Connect(ctx); err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
defer func() {
if err := relayClient.Close(); err != nil {
log.Errorf("failed to close client: %s", err)
}
}()
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func(_ string) {
@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
log.Errorf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn(ctx, "bob")

View File

@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus"
)
var (
connectionTimeout = 30 * time.Second
const (
DefaultConnectionTimeout = 30 * time.Second
)
type DialeFn interface {
@ -25,16 +25,18 @@ type dialResult struct {
}
type RaceDial struct {
log *log.Entry
serverURL string
dialerFns []DialeFn
log *log.Entry
serverURL string
dialerFns []DialeFn
connectionTimeout time.Duration
}
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
connectionTimeout: connectionTimeout,
}
}
@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())

View File

@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
protocolStr: proto,
}
rd := NewRaceDial(logger, serverURL, mockDialer)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
protocolStr: "proto2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
}
_ = conn.Close()
}
func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
protocolStr: "proto1",
}
rd := NewRaceDial(logger, serverURL, mockDialer)
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
protocolStr: proto2,
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)

View File

@ -8,7 +8,8 @@ import (
log "github.com/sirupsen/logrus"
)
var (
const (
// TODO: make it configurable, the manager should validate all configurable parameters
reconnectingTimeout = 60 * time.Second
)

View File

@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
type OnServerCloseListener func()
// ManagerService is the interface for the relay manager.
type ManagerService interface {
Serve() error
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURLs() []string
HasRelayAddress() bool
UpdateToken(token *relayAuth.Token) error
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a

View File

@ -13,7 +13,9 @@ import (
)
func TestEmptyURL(t *testing.T) {
mgr := NewManager(context.Background(), nil, "alice")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr := NewManager(ctx, nil, "alice")
err := mgr.Serve()
if err == nil {
t.Errorf("expected error, got nil")
@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
}
}
func TestForeginAutoClose(t *testing.T) {
func TestForeignAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
keepUnusedServerTime = 2 * time.Second
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err)
}
// Set up a disconnect listener to track when foreign server disconnects
foreignServerURL := toURL(srvCfg2)[0]
disconnected := make(chan struct{})
onDisconnect := func() {
select {
case disconnected <- struct{}{}:
default:
}
}
t.Log("open connection to another peer")
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
t.Fatalf("should have failed to open connection to another peer")
}
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
// Add the disconnect listener after the connection attempt
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
t.Logf("failed to add close listener (expected if connection failed): %s", err)
}
// Wait for cleanup to happen
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
t.Logf("waiting for relay cleanup: %s", timeout)
time.Sleep(timeout)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
select {
case <-disconnected:
t.Log("foreign relay connection cleaned up successfully")
case <-time.After(timeout):
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
}
t.Logf("closing manager")
@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{
Address: "localhost:1234",
@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) {
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
if err := srv.Listen(srvCfg); err != nil {
errChan <- err
}
}()

View File

@ -4,38 +4,76 @@ import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(1 * time.Second):
// Test passes if no timeout received
}
}
func TestNewReceiverNotReceive(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
// Test passes if timeout is received
case <-time.After(2 * time.Second):
t.Error("timeout not received")
}
}
func TestNewReceiverAck(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat()
@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
testMutex.Lock()
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))

View File

@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx)
senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
close(senderExit)
}()
go func() {
responded := false
@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
t.Fatalf("should have timed out before %s", testTimeout)
}
select {
case <-senderExit:
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
})
}

View File

@ -20,12 +20,12 @@ type Metrics struct {
TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram
peers metric.Int64UpDownCounter
peerActivityChan chan string
peerLastActive map[string]time.Time
mutexActivity sync.Mutex
ctx context.Context
peerReconnections metric.Int64Counter
peers metric.Int64UpDownCounter
peerActivityChan chan string
peerLastActive map[string]time.Time
mutexActivity sync.Mutex
ctx context.Context
}
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err
}
peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
)
if err != nil {
return nil, err
}
m := &Metrics{
Meter: meter,
TransferBytesSent: bytesSent,
@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime,
peers: peers,
peerReconnections: peerReconnections,
ctx: ctx,
peerActivityChan: make(chan string, 10),
@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
delete(m.peerLastActive, id)
}
func (m *Metrics) RecordPeerReconnection() {
m.peerReconnections.Add(m.ctx, 1)
}
// PeerActivity increases the active connections
func (m *Metrics) PeerActivity(peerID string) {
select {

View File

@ -18,12 +18,9 @@ type Listener struct {
TLSConfig *tls.Config
listener *quic.Listener
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{
EnableDatagrams: true,
InitialPacketSize: 1452,
@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session)
l.acceptFn(conn)
acceptFn(conn)
}
}

View File

@ -32,6 +32,9 @@ type Peer struct {
notifier *store.PeerNotifier
peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
notificationMutex sync.Mutex
}
// NewPeer creates a new Peer instance and prepare custom logging
@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
}
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
// collect online peers to response back to the caller
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
if len(onlinePeers) == 0 {
return
}
p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers)
}
@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
}
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)

View File

@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) {
return nil, fmt.Errorf("creating app metrics: %v", err)
}
peerStore := store.NewStore()
r := &Relay{
metrics: m,
metricsCancel: metricsCancel,
validator: config.AuthValidator,
instanceURL: config.instanceURL,
store: peerStore,
notifier: store.NewPeerNotifier(peerStore),
store: store.NewStore(),
notifier: store.NewPeerNotifier(),
}
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now()
r.store.AddPeer(peer)
if isReconnection := r.store.AddPeer(peer); isReconnection {
r.metrics.RecordPeerReconnection()
}
r.notifier.PeerCameOnline(peer.ID())
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String())
go func() {
peer.Work()
r.notifier.PeerWentOffline(peer.ID())
r.store.DeletePeer(peer)
if deleted := r.store.DeletePeer(peer); deleted {
r.notifier.PeerWentOffline(peer.ID())
}
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
}()

View File

@ -7,24 +7,27 @@ import (
"github.com/netbirdio/netbird/relay/messages"
)
type Listener struct {
ctx context.Context
store *Store
type event struct {
peerID messages.PeerID
online bool
}
onlineChan chan messages.PeerID
offlineChan chan messages.PeerID
type Listener struct {
ctx context.Context
eventChan chan *event
interestedPeersForOffline map[messages.PeerID]struct{}
interestedPeersForOnline map[messages.PeerID]struct{}
mu sync.RWMutex
}
func newListener(ctx context.Context, store *Store) *Listener {
func newListener(ctx context.Context) *Listener {
l := &Listener{
ctx: ctx,
store: store,
ctx: ctx,
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
// important to use a single channel for offline and online events because with it we can ensure all events
// will be processed in the order they were sent
eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
}
@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener {
return l
}
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
availablePeers := make([]messages.PeerID, 0)
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
l.mu.Lock()
defer l.mu.Unlock()
@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer
l.interestedPeersForOnline[id] = struct{}{}
l.interestedPeersForOffline[id] = struct{}{}
}
// collect online peers to response back to the caller
for _, id := range peerIDs {
_, ok := l.store.Peer(id)
if !ok {
continue
}
availablePeers = append(availablePeers, id)
}
return availablePeers
}
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
for _, id := range peerIDs {
delete(l.interestedPeersForOffline, id)
delete(l.interestedPeersForOnline, id)
}
}
@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
select {
case <-l.ctx.Done():
return
case pID := <-l.onlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.onlineChan) > 0 {
pID = <-l.onlineChan
peers = append(peers, pID)
case e := <-l.eventChan:
peersOffline := make([]messages.PeerID, 0)
peersOnline := make([]messages.PeerID, 0)
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
}
onPeersComeOnline(peers)
case pID := <-l.offlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.offlineChan) > 0 {
pID = <-l.offlineChan
peers = append(peers, pID)
// Drain the channel to collect all events
for len(l.eventChan) > 0 {
e = <-l.eventChan
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
}
}
onPeersWentOffline(peers)
if len(peersOnline) > 0 {
onPeersComeOnline(peersOnline)
}
if len(peersOffline) > 0 {
onPeersWentOffline(peersOffline)
}
}
}
}
@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOffline[peerID]; ok {
select {
case l.offlineChan <- peerID:
case l.eventChan <- &event{
peerID: peerID,
online: false,
}:
case <-l.ctx.Done():
}
}
@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOnline[peerID]; ok {
select {
case l.onlineChan <- peerID:
case l.eventChan <- &event{
peerID: peerID,
online: true,
}:
case <-l.ctx.Done():
}
delete(l.interestedPeersForOnline, peerID)
}
}

View File

@ -8,15 +8,12 @@ import (
)
type PeerNotifier struct {
store *Store
listeners map[*Listener]context.CancelFunc
listenersMutex sync.RWMutex
}
func NewPeerNotifier(store *Store) *PeerNotifier {
func NewPeerNotifier() *PeerNotifier {
pn := &PeerNotifier{
store: store,
listeners: make(map[*Listener]context.CancelFunc),
}
return pn
@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
ctx, cancel := context.WithCancel(context.Background())
listener := newListener(ctx, pn.store)
listener := newListener(ctx)
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
pn.listenersMutex.Lock()

View File

@ -26,7 +26,9 @@ func NewStore() *Store {
}
// AddPeer adds a peer to the store
func (s *Store) AddPeer(peer IPeer) {
// If the peer already exists, it will be replaced and the old peer will be closed
// Returns true if the peer was replaced, false if it was added for the first time.
func (s *Store) AddPeer(peer IPeer) bool {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.ID()]
@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) {
}
s.peers[peer.ID()] = peer
return ok
}
// DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer IPeer) {
func (s *Store) DeletePeer(peer IPeer) bool {
s.peersLock.Lock()
defer s.peersLock.Unlock()
dp, ok := s.peers[peer.ID()]
if !ok {
return
return false
}
if dp != peer {
return
return false
}
delete(s.peers, peer.ID())
return true
}
// Peer returns a peer by its ID
@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer {
}
return peers
}
func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
listener.AddInterestedPeers(peerIDs)
// Check for currently online peers
for _, id := range peerIDs {
if _, ok := s.peers[id]; ok {
onlinePeers = append(onlinePeers, id)
}
}
return onlinePeers
}