diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index baf1a2db4..85f94b53f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -103,6 +103,7 @@ type Conn struct { workerICE *WorkerICE workerRelay *WorkerRelay + wgWatcherWg sync.WaitGroup connIDRelay nbnet.ConnectionID connIDICE nbnet.ConnectionID @@ -211,6 +212,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { // Close closes this peer Conn issuing a close event to the Conn closeCh func (conn *Conn) Close() { conn.mu.Lock() + defer conn.wgWatcherWg.Wait() defer conn.mu.Unlock() conn.log.Infof("close peer connection") @@ -252,6 +254,7 @@ func (conn *Conn) Close() { } conn.setStatusToDisconnected() + conn.log.Infof("peer connection has been closed") } // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise @@ -362,6 +365,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC } conn.workerRelay.DisableWgWatcher() + // todo consider to run conn.wgWatcherWg.Wait() here if conn.wgProxyRelay != nil { conn.wgProxyRelay.Pause() @@ -407,7 +411,12 @@ func (conn *Conn) onICEStateDisconnected() { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } - conn.workerRelay.EnableWgWatcher(conn.ctx) + + conn.wgWatcherWg.Add(1) + go func() { + defer conn.wgWatcherWg.Done() + conn.workerRelay.EnableWgWatcher(conn.ctx) + }() conn.currentConnPriority = connPriorityRelay } else { conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String()) @@ -476,7 +485,12 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } - conn.workerRelay.EnableWgWatcher(conn.ctx) + + conn.wgWatcherWg.Add(1) + go func() { + defer conn.wgWatcherWg.Done() + conn.workerRelay.EnableWgWatcher(conn.ctx) + }() wgConfigWorkaround() conn.currentConnPriority = connPriorityRelay diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 49049b3d0..589f405bc 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -32,7 +32,6 @@ type WGWatcher struct { ctx context.Context ctxCancel context.CancelFunc ctxLock sync.Mutex - waitGroup sync.WaitGroup } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -48,24 +47,24 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() - defer w.ctxLock.Unlock() if w.ctx != nil && w.ctx.Err() == nil { w.log.Errorf("WireGuard watcher already enabled") + w.ctxLock.Unlock() return } ctx, ctxCancel := context.WithCancel(parentCtx) w.ctx = ctx w.ctxCancel = ctxCancel + w.ctxLock.Unlock() initialHandshake, err := w.wgState() if err != nil { w.log.Warnf("failed to read initial wg stats: %v", err) } - w.waitGroup.Add(1) - go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake) + w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake) } // DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit @@ -81,13 +80,11 @@ func (w *WGWatcher) DisableWgWatcher() { w.ctxCancel() w.ctxCancel = nil - w.waitGroup.Wait() } // wgStateCheck help to check the state of the WireGuard handshake and relay connection func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { w.log.Infof("WireGuard watcher started") - defer w.waitGroup.Done() timer := time.NewTimer(wgHandshakeOvertime) defer timer.Stop() diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go index dfd11e74f..8bfb1af4c 100644 --- a/client/internal/peer/wg_watcher_test.go +++ b/client/internal/peer/wg_watcher_test.go @@ -49,7 +49,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) { defer cancel() onDisconnected := make(chan struct{}, 1) - watcher.EnableWgWatcher(ctx, func() { + go watcher.EnableWgWatcher(ctx, func() { mlog.Infof("onDisconnectedFn") onDisconnected <- struct{}{} }) @@ -79,10 +79,11 @@ func TestWGWatcher_ReEnable(t *testing.T) { onDisconnected := make(chan struct{}, 1) - watcher.EnableWgWatcher(ctx, func() {}) + go watcher.EnableWgWatcher(ctx, func() {}) + time.Sleep(1 * time.Second) watcher.DisableWgWatcher() - watcher.EnableWgWatcher(ctx, func() { + go watcher.EnableWgWatcher(ctx, func() { onDisconnected <- struct{}{} })