Refactor network monitor to wait for stop (#1992)

This commit is contained in:
Viktor Liu 2024-05-17 09:43:18 +02:00 committed by GitHub
parent a5811a2d7d
commit bd58eea8ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 74 additions and 52 deletions

View File

@ -133,7 +133,7 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
networkWatcher *networkmonitor.NetworkWatcher networkMonitor *networkmonitor.NetworkMonitor
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server sshServer nbssh.Server
@ -212,7 +212,6 @@ func NewEngineWithProbes(
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer, sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
networkWatcher: networkmonitor.New(),
mgmProbe: mgmProbe, mgmProbe: mgmProbe,
signalProbe: signalProbe, signalProbe: signalProbe,
relayProbe: relayProbe, relayProbe: relayProbe,
@ -229,7 +228,10 @@ func (e *Engine) Stop() error {
} }
// stopping network monitor first to avoid starting the engine again // stopping network monitor first to avoid starting the engine again
e.networkWatcher.Stop() if e.networkMonitor != nil {
e.networkMonitor.Stop()
}
log.Info("Network monitor: stopped")
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
@ -344,20 +346,8 @@ func (e *Engine) Start() error {
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveProbeEvents() e.receiveProbeEvents()
if e.config.NetworkMonitor { // starting network monitor at the very last to avoid disruptions
// starting network monitor at the very last to avoid disruptions e.startNetworkMonitor()
go e.networkWatcher.Start(e.ctx, func() {
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
})
} else {
log.Infof("Network monitor is disabled, not starting")
}
return nil return nil
} }
@ -1399,3 +1389,26 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
func (e *Engine) probeTURNs() []relay.ProbeResult { func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs) return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
} }
// startNetworkMonitor launches the network monitor in a background goroutine
// when it is enabled in the engine configuration; otherwise it only logs that
// the monitor is disabled.
//
// When the monitor reports a network change, the engine is restarted by
// calling Stop followed by Start. A failure of either step is logged and not
// propagated. An error returned by the monitor itself is logged unless it
// matches networkmonitor.ErrStopped.
func (e *Engine) startNetworkMonitor() {
	if !e.config.NetworkMonitor {
		log.Infof("Network monitor is disabled, not starting")
		return
	}

	// Keep a local handle to the instance created by this call. The restart
	// callback below ends up calling startNetworkMonitor again, which reassigns
	// e.networkMonitor; reading the field from inside the goroutine would be an
	// unsynchronized read and could start a different instance than intended.
	monitor := networkmonitor.New()
	e.networkMonitor = monitor

	go func() {
		err := monitor.Start(e.ctx, func() {
			log.Infof("Network monitor detected network change, restarting engine")
			if err := e.Stop(); err != nil {
				log.Errorf("Failed to stop engine: %v", err)
			}
			if err := e.Start(); err != nil {
				log.Errorf("Failed to start engine: %v", err)
			}
		})
		// A stopped monitor is the expected shutdown path, not a failure.
		if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
			log.Errorf("Network monitor: %v", err)
		}
	}()
}

View File

@ -2,14 +2,20 @@ package networkmonitor
import ( import (
"context" "context"
"errors"
"sync"
) )
// NetworkWatcher watches for changes in network configuration. var ErrStopped = errors.New("monitor has been stopped")
type NetworkWatcher struct {
// NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct {
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex
} }
// New creates a new network monitor. // New creates a new network monitor.
func New() *NetworkWatcher { func New() *NetworkMonitor {
return &NetworkWatcher{} return &NetworkMonitor{}
} }

View File

@ -31,7 +31,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ErrStopped
default: default:
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
@ -63,7 +63,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
} }
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name) log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
callback() go callback()
// handle route changes // handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE: case unix.RTM_ADD, syscall.RTM_DELETE:
@ -84,11 +84,11 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
switch msg.Type { switch msg.Type {
case unix.RTM_ADD: case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
callback() go callback()
case unix.RTM_DELETE: case unix.RTM_DELETE:
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 { if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
callback() go callback()
} }
} }
} }

View File

@ -5,6 +5,7 @@ package networkmonitor
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"net/netip" "net/netip"
"runtime/debug" "runtime/debug"
@ -15,20 +16,18 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
) )
// Start begins watching for network changes and calls the callback function and stops when a change is detected. // Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) { func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if nw.cancel != nil {
log.Warn("Network monitor: already running, stopping previous watcher")
nw.Stop()
}
if ctx.Err() != nil { if ctx.Err() != nil {
log.Info("Network monitor: not starting, context is already cancelled") return ctx.Err()
return
} }
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx) ctx, nw.cancel = context.WithCancel(ctx)
defer nw.Stop() nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 netip.Addr var nexthop4, nexthop6 netip.Addr
var intf4, intf6 *net.Interface var intf4, intf6 *net.Interface
@ -56,27 +55,30 @@ func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil { if err := backoff.Retry(operation, expBackOff); err != nil {
log.Errorf("Network monitor: failed to get default next hops: %v", err) return fmt.Errorf("failed to get default next hops: %w", err)
return
} }
// recover in case sys ops panic // recover in case sys ops panic
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack())) err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
} }
}() }()
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) { if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
log.Errorf("Network monitor: failed to start: %v", err) return fmt.Errorf("check change: %w", err)
} }
return nil
} }
// Stop stops the network monitor. // Stop stops the network monitor.
func (nw *NetworkWatcher) Stop() { func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil { if nw.cancel != nil {
nw.cancel() nw.cancel()
nw.cancel = nil nw.wg.Wait()
log.Info("Network monitor: stopped")
} }
} }

View File

@ -36,7 +36,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ErrStopped
// handle interface state changes // handle interface state changes
case update := <-linkChan: case update := <-linkChan:
@ -47,12 +47,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
switch update.Header.Type { switch update.Header.Type {
case syscall.RTM_DELLINK: case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name) log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
callback() go callback()
return nil return nil
case syscall.RTM_NEWLINK: case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown { if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name) log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
callback() go callback()
return nil return nil
} }
} }
@ -67,12 +67,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
// triggered on added/replaced routes // triggered on added/replaced routes
case syscall.RTM_NEWROUTE: case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
callback() go callback()
return nil return nil
case syscall.RTM_DELROUTE: case syscall.RTM_DELROUTE:
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) { if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
callback() go callback()
return nil return nil
} }
} }

View File

@ -4,8 +4,9 @@ package networkmonitor
import "context" import "context"
func (nw *NetworkWatcher) Start(context.Context, func()) { func (nw *NetworkMonitor) Start(context.Context, func()) error {
return nil
} }
func (nw *NetworkWatcher) Stop() { func (nw *NetworkMonitor) Stop() {
} }

View File

@ -48,10 +48,10 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ErrStopped
case <-ticker.C: case <-ticker.C:
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) { if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
callback() go callback()
return nil return nil
} }
} }