diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index b8cb2582f..7caafa53d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -32,8 +32,8 @@ const ( defaultWgKeepAlive = 25 * time.Second connPriorityRelay ConnPriority = 1 - connPriorityICETurn ConnPriority = 1 - connPriorityICEP2P ConnPriority = 2 + connPriorityICETurn ConnPriority = 2 + connPriorityICEP2P ConnPriority = 3 ) type WgConfig struct { @@ -66,14 +66,6 @@ type ConnConfig struct { ICEConfig icemaker.Config } -type WorkerCallbacks struct { - OnRelayReadyCallback func(info RelayConnInfo) - OnRelayStatusChanged func(ConnStatus) - - OnICEConnReadyCallback func(ConnPriority, ICEConnInfo) - OnICEStatusChanged func(ConnStatus) -} - type Conn struct { log *log.Entry mu sync.Mutex @@ -135,21 +127,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu semaphore: semaphore, } - rFns := WorkerRelayCallbacks{ - OnConnReady: conn.relayConnectionIsReady, - OnDisconnected: conn.onWorkerRelayStateDisconnected, - } - - wFns := WorkerICECallbacks{ - OnConnReady: conn.iCEConnectionIsReady, - OnStatusChanged: conn.onWorkerICEStateDisconnected, - } - ctrl := isController(config) - conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) + conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) + conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) if err != nil { return nil, err } @@ -304,7 +286,7 @@ func (conn *Conn) GetKey() string { } // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected -func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { +func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { conn.mu.Lock() defer conn.mu.Unlock() @@ -376,7 +358,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon } // todo review to make sense to handle connecting and disconnected status also? -func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { +func (conn *Conn) onICEStateDisconnected() { conn.mu.Lock() defer conn.mu.Unlock() @@ -384,7 +366,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { return } - conn.log.Tracef("ICE connection state changed to %s", newState) + conn.log.Tracef("ICE connection state changed to disconnected") if conn.wgProxyICE != nil { if err := conn.wgProxyICE.CloseConn(); err != nil { @@ -404,10 +386,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.currentConnPriority = connPriorityRelay } - changed := conn.statusICE.Get() != newState && newState != StatusConnecting - conn.statusICE.Set(newState) - - conn.guard.SetICEConnDisconnected(changed) + changed := conn.statusICE.Get() != StatusDisconnected + if changed { + conn.guard.SetICEConnDisconnected() + } + conn.statusICE.Set(StatusDisconnected) peerState := State{ PubKey: conn.config.Key, @@ -422,7 +405,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { } } -func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { +func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.mu.Lock() defer conn.mu.Unlock() @@ -474,7 +457,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } -func (conn *Conn) onWorkerRelayStateDisconnected() { +func (conn *Conn) onRelayDisconnected() { conn.mu.Lock() defer conn.mu.Unlock() @@ -497,8 +480,10 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { } changed := conn.statusRelay.Get() != StatusDisconnected + if changed { + conn.guard.SetRelayedConnDisconnected() + } conn.statusRelay.Set(StatusDisconnected) - conn.guard.SetRelayedConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go index bf3527a62..1fc2b4a4a 100644 --- a/client/internal/peer/guard/guard.go +++ b/client/internal/peer/guard/guard.go @@ -29,8 +29,8 @@ type Guard struct { isConnectedOnAllWay isConnectedFunc timeout time.Duration srWatcher *SRWatcher - relayedConnDisconnected chan bool - iCEConnDisconnected chan bool + relayedConnDisconnected chan struct{} + iCEConnDisconnected chan struct{} } func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { @@ -41,8 +41,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, isConnectedOnAllWay: isConnectedFn, timeout: timeout, srWatcher: srWatcher, - relayedConnDisconnected: make(chan bool, 1), - iCEConnDisconnected: make(chan bool, 1), + relayedConnDisconnected: make(chan struct{}, 1), + iCEConnDisconnected: make(chan struct{}, 1), } } @@ -54,16 +54,16 @@ func (g *Guard) Start(ctx context.Context) { } } -func (g *Guard) SetRelayedConnDisconnected(changed bool) { +func (g *Guard) SetRelayedConnDisconnected() { select { - case g.relayedConnDisconnected <- changed: + case g.relayedConnDisconnected <- struct{}{}: default: } } -func (g *Guard) SetICEConnDisconnected(changed bool) { +func (g *Guard) SetICEConnDisconnected() { select { - case g.iCEConnDisconnected <- changed: + case g.iCEConnDisconnected <- struct{}{}: default: } } @@ -96,19 +96,13 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) { g.triggerOfferSending() } - case changed := <-g.relayedConnDisconnected: - if !changed { - continue - } + case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, reset reconnection ticker") ticker.Stop() ticker = g.prepareExponentTicker(ctx) tickerChannel = ticker.C - case changed := <-g.iCEConnDisconnected: - if !changed { - continue - } + case <-g.iCEConnDisconnected: g.log.Debugf("ICE connection changed, reset reconnection ticker") ticker.Stop() ticker = g.prepareExponentTicker(ctx) @@ -138,16 +132,10 @@ func (g *Guard) listenForDisconnectEvents(ctx context.Context) { g.log.Infof("start listen for reconnect events...") for { select { - case changed := <-g.relayedConnDisconnected: - if !changed { - continue - } + case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, triggering reconnect") g.triggerOfferSending() - case changed := <-g.iCEConnDisconnected: - if !changed { - continue - } + case <-g.iCEConnDisconnected: g.log.Debugf("ICE state changed, try to send new offer") g.triggerOfferSending() case <-srReconnectedChan: diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go new file mode 100644 index 000000000..6670c6517 --- /dev/null +++ b/client/internal/peer/wg_watcher.go @@ -0,0 +1,154 @@ +package peer + +import ( + "context" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +const ( + wgHandshakePeriod = 3 * time.Minute +) + +var ( + wgHandshakeOvertime = 30 * time.Second // allowed delay in network + checkPeriod = wgHandshakePeriod + wgHandshakeOvertime +) + +type WGInterfaceStater interface { + GetStats(key string) (configurer.WGStats, error) +} + +type WGWatcher struct { + log *log.Entry + wgIfaceStater WGInterfaceStater + peerKey string + + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex + waitGroup sync.WaitGroup +} + +func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher { + return &WGWatcher{ + log: log, + wgIfaceStater: wgIfaceStater, + peerKey: peerKey, + } +} + +// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. +func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { + w.log.Debugf("enable WireGuard watcher") + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctx != nil && w.ctx.Err() == nil { + w.log.Errorf("WireGuard watcher already enabled") + return + } + + ctx, ctxCancel := context.WithCancel(parentCtx) + w.ctx = ctx + w.ctxCancel = ctxCancel + + initialHandshake, err := w.wgState() + if err != nil { + w.log.Warnf("failed to read initial wg stats: %v", err) + } + + w.waitGroup.Add(1) + go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake) +} + +// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit +func (w *WGWatcher) DisableWgWatcher() { + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctxCancel == nil { + return + } + + w.log.Debugf("disable WireGuard watcher") + + w.ctxCancel() + w.ctxCancel = nil + w.waitGroup.Wait() +} + +// wgStateCheck help to check the state of the WireGuard handshake and relay connection +func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { + w.log.Infof("WireGuard watcher started") + defer w.waitGroup.Done() + + timer := time.NewTimer(wgHandshakeOvertime) + defer timer.Stop() + defer ctxCancel() + + lastHandshake := initialHandshake + + for { + select { + case <-timer.C: + handshake, ok := w.handshakeCheck(lastHandshake) + if !ok { + onDisconnectedFn() + return + } + lastHandshake = *handshake + + resetTime := time.Until(handshake.Add(checkPeriod)) + timer.Reset(resetTime) + + w.log.Debugf("WireGuard watcher reset timer: %v", resetTime) + case <-ctx.Done(): + w.log.Infof("WireGuard watcher stopped") + return + } + } +} + +// handshakeCheck checks the WireGuard handshake and return the new handshake time if it is different from the previous one +func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) { + handshake, err := w.wgState() + if err != nil { + w.log.Errorf("failed to read wg stats: %v", err) + return nil, false + } + + w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) + + // the current know handshake did not change + if handshake.Equal(lastHandshake) { + w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) + return nil, false + } + + // in case if the machine is suspended, the handshake time will be in the past + if handshake.Add(checkPeriod).Before(time.Now()) { + w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) + return nil, false + } + + // error handling for handshake time in the future + if handshake.After(time.Now()) { + w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake) + return nil, false + } + + return &handshake, true +} + +func (w *WGWatcher) wgState() (time.Time, error) { + wgState, err := w.wgIfaceStater.GetStats(w.peerKey) + if err != nil { + return time.Time{}, err + } + return wgState.LastHandshake, nil +} diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go new file mode 100644 index 000000000..a5b9026ad --- /dev/null +++ b/client/internal/peer/wg_watcher_test.go @@ -0,0 +1,98 @@ +package peer + +import ( + "context" + "testing" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +type MocWgIface struct { + initial bool + lastHandshake time.Time + stop bool +} + +func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) { + if !m.initial { + m.initial = true + return configurer.WGStats{}, nil + } + + if !m.stop { + m.lastHandshake = time.Now() + } + + stats := configurer.WGStats{ + LastHandshake: m.lastHandshake, + } + + return stats, nil +} + +func (m *MocWgIface) disconnect() { + m.stop = true +} + +func TestWGWatcher_EnableWgWatcher(t *testing.T) { + checkPeriod = 5 * time.Second + wgHandshakeOvertime = 1 * time.Second + + mlog := log.WithField("peer", "tet") + mocWgIface := &MocWgIface{} + watcher := NewWGWatcher(mlog, mocWgIface, "") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + onDisconnected := make(chan struct{}, 1) + watcher.EnableWgWatcher(ctx, func() { + mlog.Infof("onDisconnectedFn") + onDisconnected <- struct{}{} + }) + + // wait for initial reading + time.Sleep(2 * time.Second) + mocWgIface.disconnect() + + select { + case <-onDisconnected: + case <-time.After(10 * time.Second): + t.Errorf("timeout") + } + watcher.DisableWgWatcher() +} + +func TestWGWatcher_ReEnable(t *testing.T) { + checkPeriod = 5 * time.Second + wgHandshakeOvertime = 1 * time.Second + + mlog := log.WithField("peer", "tet") + mocWgIface := &MocWgIface{} + watcher := NewWGWatcher(mlog, mocWgIface, "") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + onDisconnected := make(chan struct{}, 1) + + watcher.EnableWgWatcher(ctx, func() {}) + watcher.DisableWgWatcher() + + watcher.EnableWgWatcher(ctx, func() { + onDisconnected <- struct{}{} + }) + + time.Sleep(2 * time.Second) + mocWgIface.disconnect() + + select { + case <-onDisconnected: + case <-time.After(10 * time.Second): + t.Errorf("timeout") + } + watcher.DisableWgWatcher() +} diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 008318492..7dd84a98e 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -31,20 +31,15 @@ type ICEConnInfo struct { RelayedOnLocal bool } -type WorkerICECallbacks struct { - OnConnReady func(ConnPriority, ICEConnInfo) - OnStatusChanged func(ConnStatus) -} - type WorkerICE struct { ctx context.Context log *log.Entry config ConnConfig + conn *Conn signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover statusRecorder *Status hasRelayOnLocally bool - conn WorkerICECallbacks agent *ice.Agent muxAgent sync.Mutex @@ -60,16 +55,16 @@ type WorkerICE struct { lastKnownState ice.ConnectionState } -func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { +func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { w := &WorkerICE{ ctx: ctx, log: log, config: config, + conn: conn, signaler: signaler, iFaceDiscover: ifaceDiscover, statusRecorder: statusRecorder, hasRelayOnLocally: hasRelayOnLocally, - conn: callBacks, } localUfrag, localPwd, err := icemaker.GenerateICECredentials() @@ -154,8 +149,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { Relayed: isRelayed(pair), RelayedOnLocal: isRelayCandidate(pair.Local), } - w.log.Debugf("on ICE conn read to use ready") - go w.conn.OnConnReady(selectedPriority(pair), ci) + w.log.Debugf("on ICE conn is ready to use") + go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. @@ -220,7 +215,7 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: if w.lastKnownState != ice.ConnectionStateDisconnected { w.lastKnownState = ice.ConnectionStateDisconnected - w.conn.OnStatusChanged(StatusDisconnected) + w.conn.onICEStateDisconnected() } w.closeAgent(agentCancel) default: diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index c22dcdeda..56c19cd1e 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -6,52 +6,41 @@ import ( "net" "sync" "sync/atomic" - "time" log "github.com/sirupsen/logrus" relayClient "github.com/netbirdio/netbird/relay/client" ) -var ( - wgHandshakePeriod = 3 * time.Minute - wgHandshakeOvertime = 30 * time.Second -) - type RelayConnInfo struct { relayedConn net.Conn rosenpassPubKey []byte rosenpassAddr string } -type WorkerRelayCallbacks struct { - OnConnReady func(RelayConnInfo) - OnDisconnected func() -} - type WorkerRelay struct { log *log.Entry isController bool config ConnConfig + conn *Conn relayManager relayClient.ManagerService - callBacks WorkerRelayCallbacks - relayedConn net.Conn - relayLock sync.Mutex - ctxWgWatch context.Context - ctxCancelWgWatch context.CancelFunc - ctxLock sync.Mutex + relayedConn net.Conn + relayLock sync.Mutex relaySupportedOnRemotePeer atomic.Bool + + wgWatcher *WGWatcher } -func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { +func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay { r := &WorkerRelay{ log: log, isController: ctrl, config: config, + conn: conn, relayManager: relayManager, - callBacks: callbacks, + wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key), } return r } @@ -87,7 +76,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.relayedConn = relayedConn w.relayLock.Unlock() - err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected) + err = w.relayManager.AddCloseListener(srv, w.onRelayClientDisconnected) if err != nil { log.Errorf("failed to add close listener: %s", err) _ = relayedConn.Close() @@ -95,7 +84,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } w.log.Debugf("peer conn opened via Relay: %s", srv) - go w.callBacks.OnConnReady(RelayConnInfo{ + go w.conn.onRelayConnectionIsReady(RelayConnInfo{ relayedConn: relayedConn, rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, rosenpassAddr: remoteOfferAnswer.RosenpassAddr, @@ -103,32 +92,11 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { - w.log.Debugf("enable WireGuard watcher") - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil { - return - } - - ctx, ctxCancel := context.WithCancel(ctx) - w.ctxWgWatch = ctx - w.ctxCancelWgWatch = ctxCancel - - w.wgStateCheck(ctx, ctxCancel) + w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected) } func (w *WorkerRelay) DisableWgWatcher() { - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxCancelWgWatch == nil { - return - } - - w.log.Debugf("disable WireGuard watcher") - - w.ctxCancelWgWatch() + w.wgWatcher.DisableWgWatcher() } func (w *WorkerRelay) RelayInstanceAddress() (string, error) { @@ -150,57 +118,17 @@ func (w *WorkerRelay) CloseConn() { return } - err := w.relayedConn.Close() - if err != nil { + if err := w.relayedConn.Close(); err != nil { w.log.Warnf("failed to close relay connection: %v", err) } } -// wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) { - w.log.Debugf("WireGuard watcher started") - lastHandshake, err := w.wgState() - if err != nil { - w.log.Warnf("failed to read wg stats: %v", err) - lastHandshake = time.Time{} - } - - go func(lastHandshake time.Time) { - timer := time.NewTimer(wgHandshakeOvertime) - defer timer.Stop() - defer ctxCancel() - - for { - select { - case <-timer.C: - handshake, err := w.wgState() - if err != nil { - w.log.Errorf("failed to read wg stats: %v", err) - timer.Reset(wgHandshakeOvertime) - continue - } - - w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) - - if handshake.Equal(lastHandshake) { - w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) - w.relayLock.Lock() - _ = w.relayedConn.Close() - w.relayLock.Unlock() - w.callBacks.OnDisconnected() - return - } - - resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime)) - lastHandshake = handshake - timer.Reset(resetTime) - case <-ctx.Done(): - w.log.Debugf("WireGuard watcher stopped") - return - } - } - }(lastHandshake) +func (w *WorkerRelay) onWGDisconnected() { + w.relayLock.Lock() + _ = w.relayedConn.Close() + w.relayLock.Unlock() + w.conn.onRelayDisconnected() } func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { @@ -217,20 +145,7 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st return remoteRelayAddress } -func (w *WorkerRelay) wgState() (time.Time, error) { - wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key) - if err != nil { - return time.Time{}, err - } - return wgState.LastHandshake, nil -} - -func (w *WorkerRelay) onRelayMGDisconnected() { - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxCancelWgWatch != nil { - w.ctxCancelWgWatch() - } - go w.callBacks.OnDisconnected() +func (w *WorkerRelay) onRelayClientDisconnected() { + w.wgWatcher.DisableWgWatcher() + go w.conn.onRelayDisconnected() } diff --git a/relay/client/client.go b/relay/client/client.go index 3c23b70d2..9e7e54393 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -141,7 +141,6 @@ type Client struct { muInstanceURL sync.Mutex onDisconnectListener func(string) - onConnectedListener func() listenerMutex sync.Mutex } @@ -190,7 +189,6 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) - go c.notifyConnected() return nil } @@ -238,12 +236,6 @@ func (c *Client) SetOnDisconnectListener(fn func(string)) { c.onDisconnectListener = fn } -func (c *Client) SetOnConnectedListener(fn func()) { - c.listenerMutex.Lock() - defer c.listenerMutex.Unlock() - c.onConnectedListener = fn -} - // HasConns returns true if there are connections. func (c *Client) HasConns() bool { c.mu.Lock() @@ -559,16 +551,6 @@ func (c *Client) notifyDisconnected() { go c.onDisconnectListener(c.connectionURL) } -func (c *Client) notifyConnected() { - c.listenerMutex.Lock() - defer c.listenerMutex.Unlock() - - if c.onConnectedListener == nil { - return - } - go c.onConnectedListener() -} - func (c *Client) writeCloseMsg() { msg := messages.MarshalCloseMsg() _, err := c.relayConn.Write(msg) diff --git a/relay/client/guard.go b/relay/client/guard.go index b971363a8..554330ea3 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -14,8 +14,9 @@ var ( // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { - // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance. + // OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance. OnNewRelayClient chan *Client + OnReconnected chan struct{} serverPicker *ServerPicker } @@ -23,6 +24,7 @@ type Guard struct { func NewGuard(sp *ServerPicker) *Guard { g := &Guard{ OnNewRelayClient: make(chan *Client, 1), + OnReconnected: make(chan struct{}, 1), serverPicker: sp, } return g @@ -39,14 +41,13 @@ func NewGuard(sp *ServerPicker) *Guard { // - relayClient: The relay client instance that was disconnected. // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { - if relayClient == nil { - goto RETRY - } - if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) { + // try to reconnect to the same server + if ok := g.tryToQuickReconnect(ctx, relayClient); ok { + g.notifyReconnected() return } -RETRY: + // start a ticker to pick a new server ticker := exponentTicker(ctx) defer ticker.Stop() @@ -64,6 +65,28 @@ RETRY: } } +func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool { + if rc == nil { + return false + } + + if !g.isServerURLStillValid(rc) { + return false + } + + if cancelled := waiteBeforeRetry(parentCtx); !cancelled { + return false + } + + log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) + + if err := rc.Connect(); err != nil { + log.Errorf("failed to reconnect to relay server: %s", err) + return false + } + return true +} + func (g *Guard) retry(ctx context.Context) error { log.Infof("try to pick up a new Relay server") relayClient, err := g.serverPicker.PickServer(ctx) @@ -78,23 +101,6 @@ func (g *Guard) retry(ctx context.Context) error { return nil } -func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool { - ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond) - defer cancel() - <-ctx.Done() - - if parentCtx.Err() != nil { - return false - } - log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) - - if err := rc.Connect(); err != nil { - log.Errorf("failed to reconnect to relay server: %s", err) - return false - } - return true -} - func (g *Guard) drainRelayClientChan() { select { case <-g.OnNewRelayClient: @@ -111,6 +117,13 @@ func (g *Guard) isServerURLStillValid(rc *Client) bool { return false } +func (g *Guard) notifyReconnected() { + select { + case g.OnReconnected <- struct{}{}: + default: + } +} + func exponentTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 2 * time.Second, @@ -121,3 +134,15 @@ func exponentTicker(ctx context.Context) *backoff.Ticker { return backoff.NewTicker(bo) } + +func waiteBeforeRetry(ctx context.Context) bool { + timer := time.NewTimer(1500 * time.Millisecond) + defer timer.Stop() + + select { + case <-timer.C: + return true + case <-ctx.Done(): + return false + } +} diff --git a/relay/client/manager.go b/relay/client/manager.go index d847bb879..26b113050 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -165,6 +165,9 @@ func (m *Manager) Ready() bool { } func (m *Manager) SetOnReconnectedListener(f func()) { + m.listenerLock.Lock() + defer m.listenerLock.Unlock() + m.onReconnectedListenerFn = f } @@ -284,6 +287,9 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { } func (m *Manager) onServerConnected() { + m.listenerLock.Lock() + defer m.listenerLock.Unlock() + if m.onReconnectedListenerFn == nil { return } @@ -304,8 +310,11 @@ func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) listenGuardEvent(ctx context.Context) { for { select { + case <-m.reconnectGuard.OnReconnected: + m.onServerConnected() case rc := <-m.reconnectGuard.OnNewRelayClient: m.storeClient(rc) + m.onServerConnected() case <-ctx.Done(): return } @@ -317,7 +326,6 @@ func (m *Manager) storeClient(client *Client) { defer m.relayClientMu.Unlock() m.relayClient = client - m.relayClient.SetOnConnectedListener(m.onServerConnected) m.relayClient.SetOnDisconnectListener(m.onServerDisconnected) }