From 636a0e2475ac28e68ff0b8899a64483f1be5fed4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 10 Mar 2025 13:32:12 +0100 Subject: [PATCH] [client] Fix engine restart (#3435) - Refactor the network monitoring to handle one event and it after return - In the engine restart cancel the upper layer context and the responsibility of the engine stop will be the upper layer - Before triggering a restart, the engine checks whether the state is already down. This helps avoid unnecessary delayed network restart events. --- client/internal/connect.go | 11 +- client/internal/engine.go | 48 +++----- .../{monitor_bsd.go => check_change_bsd.go} | 16 +-- ...monitor_linux.go => check_change_linux.go} | 7 +- ...tor_windows.go => check_change_windows.go} | 13 +- client/internal/networkmonitor/monitor.go | 113 +++++++++++++++++- .../networkmonitor/monitor_generic.go | 82 ------------- .../internal/networkmonitor/monitor_mobile.go | 17 ++- .../internal/networkmonitor/monitor_test.go | 99 +++++++++++++++ 9 files changed, 254 insertions(+), 152 deletions(-) rename client/internal/networkmonitor/{monitor_bsd.go => check_change_bsd.go} (90%) rename client/internal/networkmonitor/{monitor_linux.go => check_change_linux.go} (93%) rename client/internal/networkmonitor/{monitor_windows.go => check_change_windows.go} (89%) delete mode 100644 client/internal/networkmonitor/monitor_generic.go create mode 100644 client/internal/networkmonitor/monitor_test.go diff --git a/client/internal/connect.go b/client/internal/connect.go index 7cbe47b74..504c88c6f 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -161,7 +161,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan defer c.statusRecorder.ClientStop() operation := func() error { // if context cancelled we not start new backoff cycle - if c.isContextCancelled() { + if c.ctx.Err() != nil { return nil } @@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error { return nil } -func (c *ConnectClient) isContextCancelled() bool { - select { - case <-c.ctx.Done(): - return true - default: - return false - } -} - // SetNetworkMapPersistence enables or disables network map persistence. // When enabled, the last received network map will be stored and can be retrieved // through the Engine's getLatestNetworkMap method. When disabled, any stored diff --git a/client/internal/engine.go b/client/internal/engine.go index cedf8364c..2693976dd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1589,16 +1589,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult { return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns) } +// restartEngine restarts the engine by cancelling the client context func (e *Engine) restartEngine() { - log.Info("restarting engine") - CtxGetState(e.ctx).Set(StatusConnecting) + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() - if err := e.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) + if e.ctx.Err() != nil { + return } + log.Info("restarting engine") + CtxGetState(e.ctx).Set(StatusConnecting) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) - log.Infof("cancelling client, engine will be recreated") + log.Infof("cancelling client context, engine will be recreated") e.clientCancel() } @@ -1610,34 +1613,17 @@ func (e *Engine) startNetworkMonitor() { e.networkMonitor = networkmonitor.New() go func() { - var mu sync.Mutex - var debounceTimer *time.Timer - - // Start the network monitor with a callback, Start will block until the monitor is stopped, - // a network change is detected, or an error occurs on start up - err := e.networkMonitor.Start(e.ctx, func() { - // This function is called when a network change is detected - mu.Lock() - defer mu.Unlock() - - if debounceTimer != nil { - log.Infof("Network monitor: detected network change, reset debounceTimer") - debounceTimer.Stop() + if err := e.networkMonitor.Listen(e.ctx); err != nil { + if errors.Is(err, context.Canceled) { + log.Infof("network monitor stopped") + return } - - // Set a new timer to debounce rapid network changes - debounceTimer = time.AfterFunc(2*time.Second, func() { - // This function is called after the debounce period - mu.Lock() - defer mu.Unlock() - - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() - }) - }) - if err != nil && !errors.Is(err, networkmonitor.ErrStopped) { - log.Errorf("Network monitor: %v", err) + log.Errorf("network monitor error: %v", err) + return } + + log.Infof("Network monitor: detected network change, restarting engine") + e.restartEngine() }() } diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/check_change_bsd.go similarity index 90% rename from client/internal/networkmonitor/monitor_bsd.go rename to client/internal/networkmonitor/check_change_bsd.go index 4dc2c1aa3..bb327a877 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) if err != nil { return fmt.Errorf("failed to open routing socket: %v", err) @@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca } }() - go func() { - <-ctx.Done() - err := unix.Close(fd) - if err != nil && !errors.Is(err, unix.EBADF) { - log.Debugf("Network monitor: closed routing socket: %v", err) - } - }() - for { select { case <-ctx.Done(): - return ErrStopped + return ctx.Err() default: buf := make([]byte, 2048) n, err := unix.Read(fd, buf) @@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca switch msg.Type { case unix.RTM_ADD: log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - go callback() + return nil case unix.RTM_DELETE: if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - go callback() + return nil } } } diff --git a/client/internal/networkmonitor/monitor_linux.go b/client/internal/networkmonitor/check_change_linux.go similarity index 93% rename from client/internal/networkmonitor/monitor_linux.go rename to client/internal/networkmonitor/check_change_linux.go index 035be1f09..efd8b5884 100644 --- a/client/internal/networkmonitor/monitor_linux.go +++ b/client/internal/networkmonitor/check_change_linux.go @@ -14,7 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { if nexthopv4.Intf == nil && nexthopv6.Intf == nil { return errors.New("no interfaces available") } @@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca for { select { case <-ctx.Done(): - return ErrStopped - + return ctx.Err() // handle route changes case route := <-routeChan: // default route and main table @@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca // triggered on added/replaced routes case syscall.RTM_NEWROUTE: log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) - go callback() return nil case syscall.RTM_DELROUTE: if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) { log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) - go callback() return nil } } diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/check_change_windows.go similarity index 89% rename from client/internal/networkmonitor/monitor_windows.go rename to client/internal/networkmonitor/check_change_windows.go index cd48c269d..582865738 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/check_change_windows.go @@ -10,7 +10,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { routeMonitor, err := systemops.NewRouteMonitor(ctx) if err != nil { return fmt.Errorf("failed to create route monitor: %w", err) @@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca for { select { case <-ctx.Done(): - return ErrStopped + return ctx.Err() case route := <-routeMonitor.RouteUpdates(): if route.Destination.Bits() != 0 { continue } - if routeChanged(route, nexthopv4, nexthopv6, callback) { - break + if routeChanged(route, nexthopv4, nexthopv6) { + return nil } } } } -func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool { +func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool { intf := "" if route.Interface != nil { intf = route.Interface.Name @@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne case systemops.RouteModified: // TODO: get routing table to figure out if our route is affected for modified routes log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) - go callback() return true case systemops.RouteAdded: if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) - go callback() return true } case systemops.RouteDeleted: if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) - go callback() return true } } diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index 5475455c6..5896b66b6 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -1,12 +1,27 @@ +//go:build !ios && !android + package networkmonitor import ( "context" "errors" + "fmt" + "net/netip" + "runtime/debug" "sync" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -var ErrStopped = errors.New("monitor has been stopped") +const ( + debounceTime = 2 * time.Second +) + +var checkChangeFn = checkChange // NetworkMonitor watches for changes in network configuration. type NetworkMonitor struct { @@ -19,3 +34,99 @@ type NetworkMonitor struct { func New() *NetworkMonitor { return &NetworkMonitor{} } + +// Listen begins monitoring network changes. When a change is detected, this function will return without error. +func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { + nw.mu.Lock() + if nw.cancel != nil { + nw.mu.Unlock() + return errors.New("network monitor already started") + } + + ctx, nw.cancel = context.WithCancel(ctx) + defer nw.cancel() + nw.wg.Add(1) + nw.mu.Unlock() + + defer nw.wg.Done() + + var nexthop4, nexthop6 systemops.Nexthop + + operation := func() error { + var errv4, errv6 error + nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified()) + nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified()) + + if errv4 != nil && errv6 != nil { + return errors.New("failed to get default next hops") + } + + if errv4 == nil { + log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name) + } + if errv6 == nil { + log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name) + } + + // continue if either route was found + return nil + } + + expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) + + if err := backoff.Retry(operation, expBackOff); err != nil { + return fmt.Errorf("failed to get default next hops: %w", err) + } + + // recover in case sys ops panic + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack()) + } + }() + + event := make(chan struct{}, 1) + go nw.checkChanges(ctx, event, nexthop4, nexthop6) + + // debounce changes + timer := time.NewTimer(0) + timer.Stop() + for { + select { + case <-event: + timer.Reset(debounceTime) + case <-timer.C: + return nil + case <-ctx.Done(): + timer.Stop() + return ctx.Err() + } + } +} + +// Stop stops the network monitor. +func (nw *NetworkMonitor) Stop() { + nw.mu.Lock() + defer nw.mu.Unlock() + + if nw.cancel == nil { + return + } + + nw.cancel() + nw.wg.Wait() +} + +func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) { + for { + if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil { + close(event) + return + } + // prevent blocking + select { + case event <- struct{}{}: + default: + } + } +} diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go deleted file mode 100644 index 19648edba..000000000 --- a/client/internal/networkmonitor/monitor_generic.go +++ /dev/null @@ -1,82 +0,0 @@ -//go:build !ios && !android - -package networkmonitor - -import ( - "context" - "errors" - "fmt" - "net/netip" - "runtime/debug" - - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" -) - -// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns. -func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) { - if ctx.Err() != nil { - return ctx.Err() - } - - nw.mu.Lock() - ctx, nw.cancel = context.WithCancel(ctx) - nw.mu.Unlock() - - nw.wg.Add(1) - defer nw.wg.Done() - - var nexthop4, nexthop6 systemops.Nexthop - - operation := func() error { - var errv4, errv6 error - nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified()) - nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified()) - - if errv4 != nil && errv6 != nil { - return errors.New("failed to get default next hops") - } - - if errv4 == nil { - log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name) - } - if errv6 == nil { - log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name) - } - - // continue if either route was found - return nil - } - - expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) - - if err := backoff.Retry(operation, expBackOff); err != nil { - return fmt.Errorf("failed to get default next hops: %w", err) - } - - // recover in case sys ops panic - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack()) - } - }() - - if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil { - return fmt.Errorf("check change: %w", err) - } - - return nil -} - -// Stop stops the network monitor. -func (nw *NetworkMonitor) Stop() { - nw.mu.Lock() - defer nw.mu.Unlock() - - if nw.cancel != nil { - nw.cancel() - nw.wg.Wait() - } -} diff --git a/client/internal/networkmonitor/monitor_mobile.go b/client/internal/networkmonitor/monitor_mobile.go index c81fad16c..861dbbe3c 100644 --- a/client/internal/networkmonitor/monitor_mobile.go +++ b/client/internal/networkmonitor/monitor_mobile.go @@ -2,10 +2,21 @@ package networkmonitor -import "context" +import ( + "context" + "fmt" +) -func (nw *NetworkMonitor) Start(context.Context, func()) error { - return nil +type NetworkMonitor struct { +} + +// New creates a new network monitor. +func New() *NetworkMonitor { + return &NetworkMonitor{} +} + +func (nw *NetworkMonitor) Listen(_ context.Context) error { + return fmt.Errorf("network monitor not supported on mobile platforms") } func (nw *NetworkMonitor) Stop() { diff --git a/client/internal/networkmonitor/monitor_test.go b/client/internal/networkmonitor/monitor_test.go new file mode 100644 index 000000000..164686689 --- /dev/null +++ b/client/internal/networkmonitor/monitor_test.go @@ -0,0 +1,99 @@ +package networkmonitor + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +type MocMultiEvent struct { + counter int +} + +func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + if m.counter == 0 { + <-ctx.Done() + return ctx.Err() + } + + time.Sleep(1 * time.Second) + m.counter-- + return nil +} + +func TestNetworkMonitor_Close(t *testing.T) { + checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + <-ctx.Done() + return ctx.Err() + } + nw := New() + + var resErr error + done := make(chan struct{}) + go func() { + resErr = nw.Listen(context.Background()) + close(done) + }() + + time.Sleep(1 * time.Second) // wait for the goroutine to start + nw.Stop() + + <-done + if !errors.Is(resErr, context.Canceled) { + t.Errorf("unexpected error: %v", resErr) + } +} + +func TestNetworkMonitor_Event(t *testing.T) { + checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + timeout, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timeout.Done(): + return nil + } + } + nw := New() + defer nw.Stop() + + var resErr error + done := make(chan struct{}) + go func() { + resErr = nw.Listen(context.Background()) + close(done) + }() + + <-done + if !errors.Is(resErr, nil) { + t.Errorf("unexpected error: %v", nil) + } +} + +func TestNetworkMonitor_MultiEvent(t *testing.T) { + eventsRepeated := 3 + me := &MocMultiEvent{counter: eventsRepeated} + checkChangeFn = me.checkChange + + nw := New() + defer nw.Stop() + + done := make(chan struct{}) + started := time.Now() + go func() { + if resErr := nw.Listen(context.Background()); resErr != nil { + t.Errorf("unexpected error: %v", resErr) + } + close(done) + }() + + <-done + expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime + if time.Since(started) < expectedResponseTime { + t.Errorf("unexpected duration: %v", time.Since(started)) + } +}