mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 01:41:17 +01:00
Fix WireGuard watcher related issues: fix race handling between TURN and Relayed reconnection; move the WgWatcher logic to a separate struct; handle timeouts in a more defensive way; fix initial Relay client reconnection to the home server.
155 lines
3.9 KiB
Go
155 lines
3.9 KiB
Go
package peer
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
)
|
|
|
|
// wgHandshakePeriod is the base interval used to derive checkPeriod.
const wgHandshakePeriod = 3 * time.Minute

// wgHandshakeOvertime is the allowed delay in network. It is added on top of
// wgHandshakePeriod and is also the delay before the watcher's first check.
var wgHandshakeOvertime = 30 * time.Second

// checkPeriod is the maximum accepted age of a handshake: the watcher schedules
// its next check checkPeriod after the last observed handshake, and a handshake
// older than checkPeriod is treated as timed out.
var checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
|
|
|
|
type WGInterfaceStater interface {
|
|
GetStats(key string) (configurer.WGStats, error)
|
|
}
|
|
|
|
type WGWatcher struct {
|
|
log *log.Entry
|
|
wgIfaceStater WGInterfaceStater
|
|
peerKey string
|
|
|
|
ctx context.Context
|
|
ctxCancel context.CancelFunc
|
|
ctxLock sync.Mutex
|
|
waitGroup sync.WaitGroup
|
|
}
|
|
|
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
|
|
return &WGWatcher{
|
|
log: log,
|
|
wgIfaceStater: wgIfaceStater,
|
|
peerKey: peerKey,
|
|
}
|
|
}
|
|
|
|
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
|
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
|
w.log.Debugf("enable WireGuard watcher")
|
|
w.ctxLock.Lock()
|
|
defer w.ctxLock.Unlock()
|
|
|
|
if w.ctx != nil && w.ctx.Err() == nil {
|
|
w.log.Errorf("WireGuard watcher already enabled")
|
|
return
|
|
}
|
|
|
|
ctx, ctxCancel := context.WithCancel(parentCtx)
|
|
w.ctx = ctx
|
|
w.ctxCancel = ctxCancel
|
|
|
|
initialHandshake, err := w.wgState()
|
|
if err != nil {
|
|
w.log.Warnf("failed to read initial wg stats: %v", err)
|
|
}
|
|
|
|
w.waitGroup.Add(1)
|
|
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
|
}
|
|
|
|
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
|
func (w *WGWatcher) DisableWgWatcher() {
|
|
w.ctxLock.Lock()
|
|
defer w.ctxLock.Unlock()
|
|
|
|
if w.ctxCancel == nil {
|
|
return
|
|
}
|
|
|
|
w.log.Debugf("disable WireGuard watcher")
|
|
|
|
w.ctxCancel()
|
|
w.ctxCancel = nil
|
|
w.waitGroup.Wait()
|
|
}
|
|
|
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
|
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
|
w.log.Infof("WireGuard watcher started")
|
|
defer w.waitGroup.Done()
|
|
|
|
timer := time.NewTimer(wgHandshakeOvertime)
|
|
defer timer.Stop()
|
|
defer ctxCancel()
|
|
|
|
lastHandshake := initialHandshake
|
|
|
|
for {
|
|
select {
|
|
case <-timer.C:
|
|
handshake, ok := w.handshakeCheck(lastHandshake)
|
|
if !ok {
|
|
onDisconnectedFn()
|
|
return
|
|
}
|
|
lastHandshake = *handshake
|
|
|
|
resetTime := time.Until(handshake.Add(checkPeriod))
|
|
timer.Reset(resetTime)
|
|
|
|
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
|
case <-ctx.Done():
|
|
w.log.Infof("WireGuard watcher stopped")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// handshakeCheck checks the WireGuard handshake and return the new handshake time if it is different from the previous one
|
|
func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
|
handshake, err := w.wgState()
|
|
if err != nil {
|
|
w.log.Errorf("failed to read wg stats: %v", err)
|
|
return nil, false
|
|
}
|
|
|
|
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
|
|
|
// the current know handshake did not change
|
|
if handshake.Equal(lastHandshake) {
|
|
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
|
return nil, false
|
|
}
|
|
|
|
// in case if the machine is suspended, the handshake time will be in the past
|
|
if handshake.Add(checkPeriod).Before(time.Now()) {
|
|
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
|
return nil, false
|
|
}
|
|
|
|
// error handling for handshake time in the future
|
|
if handshake.After(time.Now()) {
|
|
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
|
|
return nil, false
|
|
}
|
|
|
|
return &handshake, true
|
|
}
|
|
|
|
func (w *WGWatcher) wgState() (time.Time, error) {
|
|
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
return wgState.LastHandshake, nil
|
|
}
|