diff --git a/relay/client/client.go b/relay/client/client.go index 2bf679ecb..32dfbb4db 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -181,13 +181,17 @@ func (c *Client) Connect(ctx context.Context) error { return nil } - if err := c.connect(ctx); err != nil { + instanceURL, err := c.connect(ctx) + if err != nil { return err } + c.muInstanceURL.Lock() + c.instanceURL = instanceURL + c.muInstanceURL.Unlock() c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID) - c.log = c.log.WithField("relay", c.instanceURL.String()) + c.log = c.log.WithField("relay", instanceURL.String()) c.log.Infof("relay connection established") c.serviceIsRunning = true @@ -229,9 +233,18 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID) msgChannel := make(chan Msg, 100) - conn := NewConn(c, peerID, msgChannel, c.instanceURL) c.mu.Lock() + if !c.serviceIsRunning { + c.mu.Unlock() + return nil, fmt.Errorf("relay connection is not established") + } + + c.muInstanceURL.Lock() + instanceURL := c.instanceURL + c.muInstanceURL.Unlock() + conn := NewConn(c, peerID, msgChannel, instanceURL) + _, ok = c.conns[peerID] if ok { c.mu.Unlock() @@ -278,69 +291,67 @@ func (c *Client) Close() error { return c.close(true) } -func (c *Client) connect(ctx context.Context) error { +func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) conn, err := rd.Dial() if err != nil { - return err + return nil, err } c.relayConn = conn - if err = c.handShake(ctx); err != nil { + instanceURL, err := c.handShake(ctx) + if err != nil { cErr := conn.Close() if cErr != nil { c.log.Errorf("failed to close connection: %s", cErr) } - return err + return nil, err } - return nil + return instanceURL, nil } -func (c *Client) handShake(ctx context.Context) error { +func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { c.log.Errorf("failed to marshal auth message: %s", err) - return err + return nil, err } _, err = c.relayConn.Write(msg) if err != nil { c.log.Errorf("failed to send auth message: %s", err) - return err + return nil, err } buf := make([]byte, messages.MaxHandshakeRespSize) n, err := c.readWithTimeout(ctx, buf) if err != nil { c.log.Errorf("failed to read auth response: %s", err) - return err + return nil, err } _, err = messages.ValidateVersion(buf[:n]) if err != nil { - return fmt.Errorf("validate version: %w", err) + return nil, fmt.Errorf("validate version: %w", err) } msgType, err := messages.DetermineServerMessageType(buf[:n]) if err != nil { c.log.Errorf("failed to determine message type: %s", err) - return err + return nil, err } if msgType != messages.MsgTypeAuthResponse { c.log.Errorf("unexpected message type: %s", msgType) - return fmt.Errorf("unexpected message type") + return nil, fmt.Errorf("unexpected message type") } addr, err := messages.UnmarshalAuthResponse(buf[:n]) if err != nil { - return err + return nil, err } - c.muInstanceURL.Lock() - c.instanceURL = &RelayAddr{addr: addr} - c.muInstanceURL.Unlock() - return nil + return &RelayAddr{addr: addr}, nil } func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) { @@ -386,10 +397,6 @@ func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internal hc.Stop() - c.muInstanceURL.Lock() - c.instanceURL = nil - c.muInstanceURL.Unlock() - c.stateSubscription.Cleanup() c.wgReadLoop.Done() _ = c.close(false) @@ -578,8 +585,12 @@ func (c *Client) close(gracefullyExit bool) error { c.log.Warn("relay connection was already marked as not running") return nil } - c.serviceIsRunning = false + + c.muInstanceURL.Lock() + c.instanceURL = nil + c.muInstanceURL.Unlock() + c.log.Infof("closing all peer connections") c.closeAllConns() if gracefullyExit { diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index d20cdaac0..52f2833e4 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -229,16 +229,14 @@ func TestForeginAutoClose(t *testing.T) { errChan := make(chan error, 1) go func() { t.Log("binding server 1.") - err := srv1.Listen(srvCfg1) - if err != nil { + if err := srv1.Listen(srvCfg1); err != nil { errChan <- err } }() defer func() { t.Logf("closing server 1.") - err := srv1.Shutdown(ctx) - if err != nil { + if err := srv1.Shutdown(ctx); err != nil { t.Errorf("failed to close server: %s", err) } t.Logf("server 1. closed") @@ -287,15 +285,8 @@ func TestForeginAutoClose(t *testing.T) { } t.Log("open connection to another peer") - conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer") - if err != nil { - t.Fatalf("failed to bind channel: %s", err) - } - - t.Log("close conn") - err = conn.Close() - if err != nil { - t.Fatalf("failed to close connection: %s", err) + if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil { + t.Fatalf("should have failed to open connection to another peer") } timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second diff --git a/relay/client/peer_subscription.go b/relay/client/peer_subscription.go index 03e7127b3..85bd41cbd 100644 --- a/relay/client/peer_subscription.go +++ b/relay/client/peer_subscription.go @@ -3,6 +3,8 @@ package client import ( "context" "errors" + "fmt" + "sync" "time" log "github.com/sirupsen/logrus" @@ -28,6 +30,7 @@ type PeersStateSubscription struct { listenForOfflinePeers map[messages.PeerID]struct{} waitingPeers map[messages.PeerID]chan struct{} + mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers } func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription { @@ -43,24 +46,31 @@ func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offl // OnPeersOnline should be called when a notification is received that certain peers have come online. // It checks if any of the peers are being waited on and signals their availability. func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) { + s.mu.Lock() + defer s.mu.Unlock() + for _, peerID := range peersID { waitCh, ok := s.waitingPeers[peerID] if !ok { + // If meanwhile the peer was unsubscribed, we don't need to signal it continue } - close(waitCh) + waitCh <- struct{}{} delete(s.waitingPeers, peerID) + close(waitCh) } } func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) { + s.mu.Lock() relevantPeers := make([]messages.PeerID, 0, len(peersID)) for _, peerID := range peersID { if _, ok := s.listenForOfflinePeers[peerID]; ok { relevantPeers = append(relevantPeers, peerID) } } + s.mu.Unlock() if len(relevantPeers) > 0 { s.offlineCallback(relevantPeers) @@ -68,36 +78,41 @@ func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) { } // WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes. -// todo: when we unsubscribe while this is running, this will not return with error func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error { // Check if already waiting for this peer + s.mu.Lock() if _, exists := s.waitingPeers[peerID]; exists { + s.mu.Unlock() return errors.New("already waiting for peer to come online") } // Create a channel to wait for the peer to come online - waitCh := make(chan struct{}) + waitCh := make(chan struct{}, 1) s.waitingPeers[peerID] = waitCh + s.listenForOfflinePeers[peerID] = struct{}{} + s.mu.Unlock() - if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil { + if err := s.subscribeStateChange(peerID); err != nil { s.log.Errorf("failed to subscribe to peer state: %s", err) - close(waitCh) - delete(s.waitingPeers, peerID) - return err - } - - defer func() { + s.mu.Lock() if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh { close(waitCh) delete(s.waitingPeers, peerID) + delete(s.listenForOfflinePeers, peerID) } - }() + s.mu.Unlock() + return err + } // Wait for peer to come online or context to be cancelled timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout) defer cancel() select { - case <-waitCh: + case _, ok := <-waitCh: + if !ok { + return fmt.Errorf("wait for peer to come online has been cancelled") + } + s.log.Debugf("peer %s is now online", peerID) return nil case <-timeoutCtx.Done(): @@ -105,6 +120,13 @@ func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil { s.log.Errorf("failed to unsubscribe from peer state: %s", err) } + s.mu.Lock() + if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh { + close(waitCh) + delete(s.waitingPeers, peerID) + delete(s.listenForOfflinePeers, peerID) + } + s.mu.Unlock() return timeoutCtx.Err() } } @@ -112,6 +134,7 @@ func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error { msgErr := s.unsubscribeStateChange(peerIDs) + s.mu.Lock() for _, peerID := range peerIDs { if wch, ok := s.waitingPeers[peerID]; ok { close(wch) @@ -120,11 +143,15 @@ func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerI delete(s.listenForOfflinePeers, peerID) } + s.mu.Unlock() return msgErr } func (s *PeersStateSubscription) Cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + for _, waitCh := range s.waitingPeers { close(waitCh) } @@ -133,16 +160,12 @@ func (s *PeersStateSubscription) Cleanup() { s.listenForOfflinePeers = make(map[messages.PeerID]struct{}) } -func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error { - msgs, err := messages.MarshalSubPeerStateMsg(peerIDs) +func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error { + msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID}) if err != nil { return err } - for _, peer := range peerIDs { - s.listenForOfflinePeers[peer] = struct{}{} - } - for _, msg := range msgs { if _, err := s.relayConn.Write(msg); err != nil { return err diff --git a/relay/server/store/listener.go b/relay/server/store/listener.go index e5f455795..b7c5f4ce8 100644 --- a/relay/server/store/listener.go +++ b/relay/server/store/listener.go @@ -8,6 +8,7 @@ import ( ) type Listener struct { + ctx context.Context store *Store onlineChan chan messages.PeerID @@ -15,12 +16,11 @@ type Listener struct { interestedPeersForOffline map[messages.PeerID]struct{} interestedPeersForOnline map[messages.PeerID]struct{} mu sync.RWMutex - - listenerCtx context.Context } -func newListener(store *Store) *Listener { +func newListener(ctx context.Context, store *Store) *Listener { l := &Listener{ + ctx: ctx, store: store, onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol @@ -65,11 +65,10 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) { } } -func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) { - l.listenerCtx = ctx +func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) { for { select { - case <-ctx.Done(): + case <-l.ctx.Done(): return case pID := <-l.onlineChan: peers := make([]messages.PeerID, 0) @@ -102,7 +101,7 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) { if _, ok := l.interestedPeersForOffline[peerID]; ok { select { case l.offlineChan <- peerID: - case <-l.listenerCtx.Done(): + case <-l.ctx.Done(): } } } @@ -114,7 +113,7 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) { if _, ok := l.interestedPeersForOnline[peerID]; ok { select { case l.onlineChan <- peerID: - case <-l.listenerCtx.Done(): + case <-l.ctx.Done(): } delete(l.interestedPeersForOnline, peerID) } diff --git a/relay/server/store/notifier.go b/relay/server/store/notifier.go index d04db478b..ad2e53545 100644 --- a/relay/server/store/notifier.go +++ b/relay/server/store/notifier.go @@ -24,8 +24,8 @@ func NewPeerNotifier(store *Store) *PeerNotifier { func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener { ctx, cancel := context.WithCancel(context.Background()) - listener := newListener(pn.store) - go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline) + listener := newListener(ctx, pn.store) + go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline) pn.listenersMutex.Lock() pn.listeners[listener] = cancel