[client] Fix close WireGuard watcher (#3598)

This PR fixes issues with closing the WireGuard watcher by adjusting its asynchronous invocation and synchronization.

Update tests in wg_watcher_test.go to launch the watcher in a goroutine and add a delay for timing.
Modify wg_watcher.go to run the periodic handshake check synchronously by removing the waitGroup and goroutine.
Enhance conn.go to wait on the watcher wait group during connection close and add a note for potential further synchronization
This commit is contained in:
Zoltan Papp 2025-03-28 20:12:31 +01:00 committed by GitHub
parent ed5647028a
commit 21464ac770
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 11 deletions

View File

@ -103,6 +103,7 @@ type Conn struct {
workerICE *WorkerICE workerICE *WorkerICE
workerRelay *WorkerRelay workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID connIDICE nbnet.ConnectionID
@ -211,6 +212,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
func (conn *Conn) Close() { func (conn *Conn) Close() {
conn.mu.Lock() conn.mu.Lock()
defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.log.Infof("close peer connection") conn.log.Infof("close peer connection")
@ -252,6 +254,7 @@ func (conn *Conn) Close() {
} }
conn.setStatusToDisconnected() conn.setStatusToDisconnected()
conn.log.Infof("peer connection has been closed")
} }
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
@ -362,6 +365,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
} }
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause() conn.wgProxyRelay.Pause()
@ -407,7 +411,12 @@ func (conn *Conn) onICEStateDisconnected() {
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
} else { } else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String()) conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
@ -476,7 +485,12 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
wgConfigWorkaround() wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay

View File

@ -32,7 +32,6 @@ type WGWatcher struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxLock sync.Mutex ctxLock sync.Mutex
waitGroup sync.WaitGroup
} }
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@ -48,24 +47,24 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher") w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock() w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctx != nil && w.ctx.Err() == nil { if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled") w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
return return
} }
ctx, ctxCancel := context.WithCancel(parentCtx) ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx w.ctx = ctx
w.ctxCancel = ctxCancel w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
initialHandshake, err := w.wgState() initialHandshake, err := w.wgState()
if err != nil { if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err) w.log.Warnf("failed to read initial wg stats: %v", err)
} }
w.waitGroup.Add(1) w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
} }
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit // DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
@ -81,13 +80,11 @@ func (w *WGWatcher) DisableWgWatcher() {
w.ctxCancel() w.ctxCancel()
w.ctxCancel = nil w.ctxCancel = nil
w.waitGroup.Wait()
} }
// wgStateCheck help to check the state of the WireGuard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started") w.log.Infof("WireGuard watcher started")
defer w.waitGroup.Done()
timer := time.NewTimer(wgHandshakeOvertime) timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop() defer timer.Stop()

View File

@ -49,7 +49,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
defer cancel() defer cancel()
onDisconnected := make(chan struct{}, 1) onDisconnected := make(chan struct{}, 1)
watcher.EnableWgWatcher(ctx, func() { go watcher.EnableWgWatcher(ctx, func() {
mlog.Infof("onDisconnectedFn") mlog.Infof("onDisconnectedFn")
onDisconnected <- struct{}{} onDisconnected <- struct{}{}
}) })
@ -79,10 +79,11 @@ func TestWGWatcher_ReEnable(t *testing.T) {
onDisconnected := make(chan struct{}, 1) onDisconnected := make(chan struct{}, 1)
watcher.EnableWgWatcher(ctx, func() {}) go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher() watcher.DisableWgWatcher()
watcher.EnableWgWatcher(ctx, func() { go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{} onDisconnected <- struct{}{}
}) })