Fix wg handshake checking (#2590)

* Fix wg handshake checking

* Ensure in the initial handshake reading

* Change the handshake period
This commit is contained in:
Zoltan Papp 2024-09-12 19:18:02 +02:00 committed by GitHub
parent 33c9b2d989
commit ab892b8cf9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 30 deletions

View File

@ -484,11 +484,11 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
// switch back to relay connection // switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
conn.log.Debugf("ICE disconnected, set Relay to active connection") conn.log.Debugf("ICE disconnected, set Relay to active connection")
conn.workerRelay.EnableWgWatcher(conn.ctx)
err := conn.configureWGEndpoint(conn.endpointRelay) err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil { if err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
} }
@ -551,6 +551,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
} }
} }
conn.workerRelay.EnableWgWatcher(conn.ctx)
err = conn.configureWGEndpoint(endpointUdpAddr) err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil { if err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
@ -560,7 +561,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
conn.workerRelay.EnableWgWatcher(conn.ctx)
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil {

View File

@ -14,7 +14,7 @@ import (
) )
var ( var (
wgHandshakePeriod = 2 * time.Minute wgHandshakePeriod = 3 * time.Minute
wgHandshakeOvertime = 30 * time.Second wgHandshakeOvertime = 30 * time.Second
) )
@ -109,7 +109,7 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
} }
ctx, ctxCancel := context.WithCancel(ctx) ctx, ctxCancel := context.WithCancel(ctx)
go w.wgStateCheck(ctx) w.wgStateCheck(ctx)
w.ctxWgWatch = ctx w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel w.ctxCancelWgWatch = ctxCancel
@ -157,37 +157,50 @@ func (w *WorkerRelay) CloseConn() {
} }
} }
// wgStateCheck help to check the state of the wireguard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) { func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
timer := time.NewTimer(wgHandshakeOvertime) lastHandshake, err := w.wgState()
defer timer.Stop() if err != nil {
expected := wgHandshakeOvertime w.log.Errorf("failed to read wg stats: %v", err)
for { lastHandshake = time.Time{}
select { }
case <-timer.C:
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
w.log.Tracef("last handshake: %v", lastHandshake)
if time.Since(lastHandshake) > expected { go func(lastHandshake time.Time) {
w.log.Infof("Wireguard handshake timed out, closing relay connection") timer := time.NewTimer(wgHandshakeOvertime)
w.relayLock.Lock() defer timer.Stop()
_ = w.relayedConn.Close()
w.relayLock.Unlock() for {
w.callBacks.OnDisconnected() select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgHandshakeOvertime)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return return
} }
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
timer.Reset(resetTime)
expected = wgHandshakePeriod
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
} }
} }(lastHandshake)
} }
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {