diff --git a/client/internal/conn_mgr.go b/client/internal/conn_mgr.go index f7b1f6a05..88b47c511 100644 --- a/client/internal/conn_mgr.go +++ b/client/internal/conn_mgr.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/route" ) // ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections. @@ -97,6 +98,16 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er } } +// UpdateRouteHAMap updates the route HA mappings in the lazy connection manager +func (e *ConnMgr) UpdateRouteHAMap(haMap route.HAMap) { + if !e.isStartedWithLazyMgr() { + log.Debugf("lazy connection manager is not started, skipping UpdateRouteHAMap") + return + } + + e.lazyConnMgr.UpdateRouteHAMap(haMap) +} + // SetExcludeList sets the list of peer IDs that should always have permanent connections. func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) { if e.lazyConnMgr == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 0962e9004..034057fe0 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1007,7 +1007,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) - if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { + serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) + + // lazy mgr needs to be aware of which routes are available before they are applied + if e.connMgr != nil { + e.connMgr.UpdateRouteHAMap(clientRoutes) + log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes)) + } + + if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update routes: %v", err) } @@ -1067,7 +1075,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } // must set the exclude list after the peers are added. 
Without it the manager can not figure out the peers parameters from the store - excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers()) + excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers()) e.connMgr.SetExcludeList(excludedLazyPeers) e.networkSerial = serial @@ -1933,18 +1941,8 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal return forwardingRules, nberrors.FormatErrorOrNil(merr) } -func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool { +func (e *Engine) toExcludedLazyPeers(rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool { excludedPeers := make(map[string]bool) - for _, r := range routes { - if r.Peer == "" { - continue - } - if !excludedPeers[r.Peer] { - log.Infof("exclude router peer from lazy connection: %s", r.Peer) - excludedPeers[r.Peer] = true - } - } - for _, r := range rules { ip := r.TranslatedAddress for _, p := range peers { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 6bdd9ae3c..a4470e0ec 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -643,12 +643,12 @@ func TestEngine_Sync(t *testing.T) { func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { testCases := []struct { - name string - inputErr error - networkMap *mgmtProto.NetworkMap - expectedLen int - expectedRoutes []*route.Route - expectedSerial uint64 + name string + inputErr error + networkMap *mgmtProto.NetworkMap + expectedLen int + expectedClientRoutes route.HAMap + expectedSerial uint64 }{ { name: "Routes Config Should Be Passed To Manager", @@ -676,22 +676,26 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { }, }, expectedLen: 2, - expectedRoutes: []*route.Route{ - { - ID: "a", - Network: netip.MustParsePrefix("192.168.0.0/24"), - NetID: "n1", - Peer: "p1", - NetworkType: 1, - Masquerade: false, + expectedClientRoutes: route.HAMap{ + "n1|192.168.0.0/24": []*route.Route{ + { + ID: "a", + Network: netip.MustParsePrefix("192.168.0.0/24"), + NetID: "n1", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, }, - { - ID: "b", - Network: netip.MustParsePrefix("192.168.1.0/24"), - NetID: "n2", - Peer: "p1", - NetworkType: 1, - Masquerade: false, + "n2|192.168.1.0/24": []*route.Route{ + { + ID: "b", + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetID: "n2", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, }, }, expectedSerial: 1, @@ -704,9 +708,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { RemotePeersIsEmpty: false, Routes: nil, }, - expectedLen: 0, - expectedRoutes: []*route.Route{}, - expectedSerial: 1, + expectedLen: 0, + expectedClientRoutes: nil, + expectedSerial: 1, }, { name: "Error Shouldn't Break Engine", @@ -717,9 +721,9 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { RemotePeersIsEmpty: false, Routes: nil, }, - expectedLen: 0, - expectedRoutes: []*route.Route{}, - expectedSerial: 1, + expectedLen: 0, + expectedClientRoutes: nil, + expectedSerial: 1, }, } @@ -762,16 +766,29 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { engine.wgInterface, err = iface.NewWGIFace(opts) assert.NoError(t, err, "shouldn't return error") input := struct { - inputSerial uint64 - inputRoutes []*route.Route + inputSerial uint64 + clientRoutes route.HAMap }{} mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: 
func(updateSerial uint64, newRoutes []*route.Route) error { + UpdateRoutesFunc: func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error { input.inputSerial = updateSerial - input.inputRoutes = newRoutes + input.clientRoutes = clientRoutes return testCase.inputErr }, + ClassifyRoutesFunc: func(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) { + if len(newRoutes) == 0 { + return nil, nil + } + + // Classify all routes as client routes (not matching our public key) + clientRoutes := make(route.HAMap) + for _, r := range newRoutes { + haID := r.GetHAUniqueID() + clientRoutes[haID] = append(clientRoutes[haID], r) + } + return nil, clientRoutes + }, } engine.routeManager = mockRouteManager @@ -789,8 +806,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { err = engine.updateNetworkMap(testCase.networkMap) assert.NoError(t, err, "shouldn't return error") assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") - assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match") - assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match") + assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match") + assert.Equal(t, testCase.expectedClientRoutes, input.clientRoutes, "clientRoutes should match") }) } } @@ -951,7 +968,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + UpdateRoutesFunc: func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error { return nil }, } diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index af12a73e4..15979d553 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -6,6 +6,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn/activity" @@ -13,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/dispatcher" peerid "github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/route" ) const ( @@ -37,6 +39,7 @@ type Config struct { // - Managing inactivity monitors for lazy connections (based on peer disconnection events) // - Maintaining a list of excluded peers that should always have permanent connections // - Handling connection establishment based on peer signaling +// - Managing route HA groups and activating all peers in a group when one peer is activated type Manager struct { peerStore *peerstore.Store connStateDispatcher *dispatcher.ConnectionDispatcher @@ -51,6 +54,11 @@ type Manager struct { activityManager *activity.Manager inactivityMonitors map[peerid.ConnID]*inactivity.Monitor + // Route HA group management + peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to + haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group + routesMu sync.RWMutex // protects route mappings + cancel context.CancelFunc onInactive chan peerid.ConnID } @@ -66,6 +74,8 @@ func NewManager(config Config, peerStore *peerstore.Store, wgIface 
lazyconn.WGIf excludes: make(map[string]lazyconn.PeerConfig), activityManager: activity.NewManager(wgIface), inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor), + peerToHAGroups: make(map[string][]route.HAUniqueID), + haGroupToPeers: make(map[route.HAUniqueID][]string), onInactive: make(chan peerid.ConnID), } @@ -87,6 +97,41 @@ func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIf return m } +// UpdateRouteHAMap updates the HA group mappings for routes +// This should be called when route configuration changes +func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) { + m.routesMu.Lock() + defer m.routesMu.Unlock() + + maps.Clear(m.peerToHAGroups) + maps.Clear(m.haGroupToPeers) + + for haUniqueID, routes := range haMap { + var peers []string + + peerSet := make(map[string]bool) + for _, r := range routes { + if !peerSet[r.Peer] { + peerSet[r.Peer] = true + peers = append(peers, r.Peer) + } + } + + if len(peers) <= 1 { + continue + } + + m.haGroupToPeers[haUniqueID] = peers + + for _, peerID := range peers { + m.peerToHAGroups[peerID] = append(m.peerToHAGroups[peerID], haUniqueID) + } + } + + log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", + len(m.haGroupToPeers), len(m.peerToHAGroups)) +} + // Start starts the manager and listens for peer activity and inactivity events func (m *Manager) Start(ctx context.Context) { defer m.close() @@ -209,25 +254,47 @@ func (m *Manager) RemovePeer(peerID string) { } // ActivatePeer activates a peer connection when a signal message is received +// Also activates all peers in the same HA groups as this peer func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() + cfg, mp := m.getPeerForActivation(peerID) + if cfg == nil { + return false + } + if !m.activateSinglePeer(ctx, cfg, mp) { + return false + } + + m.activateHAGroupPeers(ctx, peerID) + + return true +} + +// getPeerForActivation checks if a peer can be activated and returns the necessary structs +// Returns nil values if the peer should be skipped +func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) { cfg, ok := m.managedPeers[peerID] if !ok { - return false + return nil, nil } mp, ok := m.managedPeersByConnID[cfg.PeerConnID] if !ok { - return false + return nil, nil } // signal messages coming continuously after success activation, with this avoid the multiple activation if mp.expectedWatcher == watcherInactivity { - return false + return nil, nil } + return cfg, mp +} + +// activateSinglePeer activates a single peer (internal method) +func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool { mp.expectedWatcher = watcherInactivity m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID) @@ -238,12 +305,53 @@ func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) return false } - mp.peerCfg.Log.Infof("starting inactivity monitor") + cfg.Log.Infof("starting inactivity monitor") go im.Start(ctx, m.onInactive) return true } +// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to +func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) { + m.routesMu.RLock() + haGroups := m.peerToHAGroups[triggerPeerID] + m.routesMu.RUnlock() + + if len(haGroups) == 0 { + log.Debugf("peer %s is not part of any HA groups", triggerPeerID) + return + } + + activatedCount := 0 + for _, haGroup := range haGroups { 
+ m.routesMu.RLock() + peers := m.haGroupToPeers[haGroup] + m.routesMu.RUnlock() + + for _, peerID := range peers { + if peerID == triggerPeerID { + continue + } + + cfg, mp := m.getPeerForActivation(peerID) + if cfg == nil { + continue + } + + if m.activateSinglePeer(ctx, cfg, mp) { + activatedCount++ + cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID) + m.peerStore.PeerConnOpen(ctx, cfg.PublicKey) + } + } + } + + if activatedCount > 0 { + log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)", + activatedCount, triggerPeerID, haGroups) + } +} + func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { peerCfg.Log.Warnf("peer already managed") @@ -297,6 +405,13 @@ func (m *Manager) close() { m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor) m.managedPeers = make(map[string]*lazyconn.PeerConfig) m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer) + + // Clear route mappings + m.routesMu.Lock() + m.peerToHAGroups = make(map[string][]route.HAUniqueID) + m.haGroupToPeers = make(map[route.HAUniqueID][]string) + m.routesMu.Unlock() + log.Infof("lazy connection manager closed") } @@ -317,10 +432,11 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) mp.peerCfg.Log.Infof("detected peer activity") - mp.expectedWatcher = watcherInactivity + if !m.activateSinglePeer(ctx, mp.peerCfg, mp) { + return + } - mp.peerCfg.Log.Infof("starting inactivity monitor") - go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive) + m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey) m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey) } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index ed2f1fe47..abeafd757 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -576,6 +576,10 @@ func (d *Status) FinishPeerListModifications() { d.mux.Unlock() d.notifyPeerListChanged() + + for key := range d.peers { + d.notifyPeerStateChangeListeners(key) + } } func (d *Status) SubscribeToPeerStateChanges(ctx context.Context, peerID string) *StatusChangeSubscription { diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index 5582591a9..6e3cf61c9 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -38,9 +38,9 @@ const ( ) type routerPeerStatus struct { - connected bool - relayed bool - latency time.Duration + status peer.ConnStatus + relayed bool + latency time.Duration } type RoutesUpdate struct { @@ -68,6 +68,7 @@ type WatcherConfig struct { // Watcher watches route and peer changes and updates allowed IPs accordingly. // Once stopped, it cannot be reused. +// The methods are not thread-safe and should be synchronized externally. 
type Watcher struct { ctx context.Context cancel context.CancelFunc @@ -78,6 +79,7 @@ type Watcher struct { peerStateUpdate chan struct{} routePeersNotifiers map[string]chan struct{} // map of peer key to channel for peer state changes currentChosen *route.Route + currentChosenStatus *routerPeerStatus handler RouteHandler updateSerial uint64 } @@ -95,6 +97,7 @@ func NewWatcher(config WatcherConfig) *Watcher { routeUpdate: make(chan RoutesUpdate), peerStateUpdate: make(chan struct{}), handler: config.Handler, + currentChosenStatus: nil, } return client } @@ -108,9 +111,9 @@ func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus { continue } routePeerStatuses[r.ID] = routerPeerStatus{ - connected: peerStatus.ConnStatus == peer.StatusConnected, - relayed: peerStatus.Relayed, - latency: peerStatus.Latency, + status: peerStatus.ConnStatus, + relayed: peerStatus.Relayed, + latency: peerStatus.Latency, } } return routePeerStatuses @@ -121,15 +124,17 @@ func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus { // preference for non-relayed and direct connections. // // It follows these prioritization rules: -// * Connected peers: Only routes with connected peers are considered. +// * Connection status: Both connected and idle peers are considered, but connected peers always take precedence. +// * Idle peer penalty: Idle peers receive a significant score penalty to ensure any connected peer is preferred. // * Metric: Routes with lower metrics (better) are prioritized. // * Non-relayed: Routes without relays are preferred. // * Latency: Routes with lower latency are prioritized. +// * Allowed IPs: Idle peers can still receive allowed IPs to enable lazy connection triggering. // * we compare the current score + 10ms to the chosen score to avoid flapping between routes // * Stability: In case of equal scores, the currently active route (if any) is maintained. // // It returns the ID of the selected optimal route. 
-func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { +func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) (route.ID, routerPeerStatus) { var chosen route.ID chosenScore := float64(0) currScore := float64(0) @@ -139,10 +144,13 @@ func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]router currID = w.currentChosen.ID } + var chosenStatus routerPeerStatus + for _, r := range w.routes { tempScore := float64(0) peerStatus, found := routePeerStatuses[r.ID] - if !found || !peerStatus.connected { + // connecting status equals disconnected: no wireguard endpoint to assign allowed IPs to + if !found || peerStatus.status == peer.StatusConnecting { continue } @@ -155,8 +163,8 @@ func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]router latency := 999 * time.Millisecond if peerStatus.latency != 0 { latency = peerStatus.latency - } else { - log.Tracef("peer %s has 0 latency, range %s", r.Peer, w.handler) + } else if !peerStatus.relayed && peerStatus.status != peer.StatusIdle { + log.Tracef("peer %s has 0 latency: [%v]", r.Peer, w.handler) } // avoid negative tempScore on the higher latency calculation @@ -167,17 +175,24 @@ func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]router // higher latency is worse score tempScore += 1 - latency.Seconds() + // apply significant penalty for idle peers to ensure connected peers always take precedence + if peerStatus.status == peer.StatusConnected { + tempScore += 100_000 + } + if !peerStatus.relayed { tempScore++ } if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") { chosen = r.ID + chosenStatus = peerStatus chosenScore = tempScore } if chosen == "" && currID == "" { chosen = r.ID + chosenStatus = peerStatus chosenScore = tempScore } @@ -204,13 +219,13 @@ func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]router peers = append(peers, r.Peer) } - log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", w.handler, peers) + log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently available", w.handler, peers) case chosen != currID: // we compare the current score + 10ms to the chosen score to avoid flapping between routes if currScore != 0 && currScore+0.01 > chosenScore { log.Debugf("keeping current routing peer %s for [%v]: the score difference with latency is less than 0.01(10ms): current: %f, new: %f", w.currentChosen.Peer, w.handler, currScore, chosenScore) - return currID + return currID, chosenStatus } var p string if rt := w.routes[chosen]; rt != nil { @@ -219,7 +234,7 @@ func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]router log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, w.handler) } - return chosen + return chosen, chosenStatus } func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) { @@ -279,10 +294,28 @@ func (w *Watcher) removeAllowedIPs(route *route.Route, rsn reason) error { return nil } +// shouldSkipRecalculation checks if we can skip route recalculation for the same route without status changes +func (w *Watcher) shouldSkipRecalculation(newChosenID route.ID, newStatus routerPeerStatus) bool { + if w.currentChosen == nil { + return false + } + + isSameRoute := w.currentChosen.ID == 
newChosenID && w.currentChosen.Equal(w.routes[newChosenID]) + if !isSameRoute { + return false + } + + if w.currentChosenStatus != nil { + return w.currentChosenStatus.status == newStatus.status + } + + return true +} + func (w *Watcher) recalculateRoutes(rsn reason) error { routerPeerStatuses := w.getRouterPeerStatuses() - newChosenID := w.getBestRouteFromStatuses(routerPeerStatuses) + newChosenID, newStatus := w.getBestRouteFromStatuses(routerPeerStatuses) // If no route is chosen, remove the route from the peer if newChosenID == "" { @@ -295,13 +328,13 @@ func (w *Watcher) recalculateRoutes(rsn reason) error { } w.currentChosen = nil + w.currentChosenStatus = nil return nil } - // If the chosen route is the same as the current route, do nothing - if w.currentChosen != nil && w.currentChosen.ID == newChosenID && - w.currentChosen.Equal(w.routes[newChosenID]) { + // If we can skip recalculation for the same route without changes, do nothing + if w.shouldSkipRecalculation(newChosenID, newStatus) { return nil } @@ -316,8 +349,12 @@ func (w *Watcher) recalculateRoutes(rsn reason) error { if err := w.addAllowedIPs(newChosenRoute); err != nil { return fmt.Errorf("add new: %w", err) } + if newStatus.status != peer.StatusIdle { + w.connectEvent(newChosenRoute) + } w.currentChosen = newChosenRoute + w.currentChosenStatus = &newStatus return nil } @@ -497,6 +534,7 @@ func (w *Watcher) Stop() { if err := w.removeAllowedIPs(w.currentChosen, reasonShutdown); err != nil { log.Errorf("Failed to remove routes for [%v]: %v", w.handler, err) } + w.currentChosenStatus = nil } func HandlerFromRoute( diff --git a/client/internal/routemanager/client/client_bench_test.go b/client/internal/routemanager/client/client_bench_test.go new file mode 100644 index 000000000..1fc41ec33 --- /dev/null +++ b/client/internal/routemanager/client/client_bench_test.go @@ -0,0 +1,155 @@ +package client + +import ( + "context" + "fmt" + "net/netip" + "sync" + "testing" + "time" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/route" +) + +type benchmarkTier struct { + name string + peers int + routes int + haPeersPerGroup int +} + +var benchmarkTiers = []benchmarkTier{ + {"Small", 100, 50, 4}, + {"Medium", 1000, 200, 16}, + {"Large", 5000, 500, 32}, +} + +type mockRouteHandler struct { + network string +} + +func (m *mockRouteHandler) String() string { return m.network } +func (m *mockRouteHandler) AddRoute(context.Context) error { return nil } +func (m *mockRouteHandler) RemoveRoute() error { return nil } +func (m *mockRouteHandler) AddAllowedIPs(string) error { return nil } +func (m *mockRouteHandler) RemoveAllowedIPs() error { return nil } + +func generateBenchmarkData(tier benchmarkTier) (*peer.Status, map[route.ID]*route.Route) { + statusRecorder := peer.NewRecorder("test-mgm") + routes := make(map[route.ID]*route.Route) + + peerKeys := make([]string, tier.peers) + for i := 0; i < tier.peers; i++ { + peerKey := fmt.Sprintf("peer-%d", i) + peerKeys[i] = peerKey + fqdn := fmt.Sprintf("peer-%d.example.com", i) + ip := fmt.Sprintf("10.0.%d.%d", i/256, i%256) + + err := statusRecorder.AddPeer(peerKey, fqdn, ip) + if err != nil { + panic(fmt.Sprintf("failed to add peer: %v", err)) + } + + var status peer.ConnStatus + var latency time.Duration + relayed := false + + switch i % 10 { + case 0, 1: // 20% disconnected + status = peer.StatusConnecting + latency = 0 + case 2: // 10% idle + status = peer.StatusIdle + latency = 50 * time.Millisecond + case 3, 4: // 20% relayed + status = 
peer.StatusConnected + relayed = true + latency = time.Duration(50+i%100) * time.Millisecond + default: // 50% direct connection + status = peer.StatusConnected + latency = time.Duration(10+i%40) * time.Millisecond + } + + // Update peer state + state := peer.State{ + PubKey: peerKey, + IP: ip, + FQDN: fqdn, + ConnStatus: status, + ConnStatusUpdate: time.Now(), + Relayed: relayed, + Latency: latency, + Mux: &sync.RWMutex{}, + } + + err = statusRecorder.UpdatePeerState(state) + if err != nil { + panic(fmt.Sprintf("failed to update peer state: %v", err)) + } + } + + routeID := 0 + for i := 0; i < tier.routes; i++ { + network := fmt.Sprintf("192.168.%d.0/24", i%256) + prefix := netip.MustParsePrefix(network) + + haGroupSize := 1 + if i%4 == 0 { // 25% of routes have HA + haGroupSize = tier.haPeersPerGroup + } + + for j := 0; j < haGroupSize; j++ { + peerIndex := (i*tier.haPeersPerGroup + j) % tier.peers + peerKey := peerKeys[peerIndex] + + rID := route.ID(fmt.Sprintf("route-%d-%d", i, j)) + + metric := 100 + j*10 + + routes[rID] = &route.Route{ + ID: rID, + Network: prefix, + Peer: peerKey, + Metric: metric, + NetID: route.NetID(fmt.Sprintf("net-%d", i)), + } + routeID++ + } + } + + return statusRecorder, routes +} + +// Benchmark the optimized recalculate routes +func BenchmarkRecalculateRoutes(b *testing.B) { + for _, tier := range benchmarkTiers { + b.Run(tier.name, func(b *testing.B) { + statusRecorder, routes := generateBenchmarkData(tier) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher := &Watcher{ + ctx: ctx, + statusRecorder: statusRecorder, + routes: routes, + routePeersNotifiers: make(map[string]chan struct{}), + routeUpdate: make(chan RoutesUpdate), + peerStateUpdate: make(chan struct{}), + handler: &mockRouteHandler{network: "benchmark"}, + currentChosenStatus: nil, + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + err := watcher.recalculateRoutes(reasonPeerUpdate) + if err != nil { + b.Fatalf("recalculateRoutes failed: %v", err) + } + } + }) + } +} diff --git a/client/internal/routemanager/client/client_test.go b/client/internal/routemanager/client/client_test.go index 48a9495bf..e7aff28b6 100644 --- a/client/internal/routemanager/client/client_test.go +++ b/client/internal/routemanager/client/client_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/route" ) @@ -23,8 +24,8 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "one route", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: false, + status: peer.StatusConnected, + relayed: false, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -41,8 +42,8 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "one connected routes with relayed and direct", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: true, + status: peer.StatusConnected, + relayed: true, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -59,8 +60,8 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "one connected routes with relayed and no direct", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: true, + status: peer.StatusConnected, + relayed: true, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -77,8 +78,8 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "no connected peers", statuses: 
map[route.ID]routerPeerStatus{ "route1": { - connected: false, - relayed: false, + status: peer.StatusConnecting, + relayed: false, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -95,12 +96,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "multiple connected peers with different metrics", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: false, + status: peer.StatusConnected, + relayed: false, }, "route2": { - connected: true, - relayed: false, + status: peer.StatusConnected, + relayed: false, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -122,12 +123,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "multiple connected peers with one relayed", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: false, + status: peer.StatusConnected, + relayed: false, }, "route2": { - connected: true, - relayed: true, + status: peer.StatusConnected, + relayed: true, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -149,12 +150,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "multiple connected peers with different latencies", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - latency: 300 * time.Millisecond, + status: peer.StatusConnected, + latency: 300 * time.Millisecond, }, "route2": { - connected: true, - latency: 10 * time.Millisecond, + status: peer.StatusConnected, + latency: 10 * time.Millisecond, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -176,12 +177,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "should ignore routes with latency 0", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - latency: 0 * time.Millisecond, + status: peer.StatusConnected, + latency: 0 * time.Millisecond, }, "route2": { - connected: true, - latency: 10 * time.Millisecond, + status: peer.StatusConnected, + latency: 10 * time.Millisecond, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -203,14 +204,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "current route with similar score and similar but slightly worse latency should not change", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: false, - latency: 15 * time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 15 * time.Millisecond, }, "route2": { - connected: true, - relayed: false, - latency: 10 * time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 10 * time.Millisecond, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -232,14 +233,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "relayed routes with latency 0 should maintain previous choice", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: true, - latency: 0 * time.Millisecond, + status: peer.StatusConnected, + relayed: true, + latency: 0 * time.Millisecond, }, "route2": { - connected: true, - relayed: true, - latency: 0 * time.Millisecond, + status: peer.StatusConnected, + relayed: true, + latency: 0 * time.Millisecond, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -261,14 +262,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "p2p routes with latency 0 should maintain previous choice", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: false, - latency: 0 * time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 0 * time.Millisecond, }, "route2": { - connected: true, - relayed: false, - latency: 0 * 
time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 0 * time.Millisecond, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -290,14 +291,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "current route with bad score should be changed to route with better score", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: false, - latency: 200 * time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 200 * time.Millisecond, }, "route2": { - connected: true, - relayed: false, - latency: 10 * time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 10 * time.Millisecond, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -319,14 +320,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name: "current chosen route doesn't exist anymore", statuses: map[route.ID]routerPeerStatus{ "route1": { - connected: true, - relayed: false, - latency: 20 * time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 20 * time.Millisecond, }, "route2": { - connected: true, - relayed: false, - latency: 10 * time.Millisecond, + status: peer.StatusConnected, + relayed: false, + latency: 10 * time.Millisecond, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -344,6 +345,422 @@ func TestGetBestrouteFromStatuses(t *testing.T) { currentRoute: "routeDoesntExistAnymore", expectedRouteID: "route2", }, + { + name: "connected peer should be preferred over idle peer", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 10 * time.Millisecond, + }, + "route2": { + status: peer.StatusConnected, + relayed: false, + latency: 100 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "idle peer should be selected when no connected peers", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 10 * time.Millisecond, + }, + "route2": { + status: peer.StatusConnecting, + relayed: false, + latency: 5 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route1", + }, + { + name: "best idle peer should be selected among multiple idle peers", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 100 * time.Millisecond, + }, + "route2": { + status: peer.StatusIdle, + relayed: false, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "connecting peers should not be considered for routing", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusConnecting, + relayed: false, + latency: 10 * time.Millisecond, + }, + "route2": { + status: peer.StatusConnecting, + relayed: false, + latency: 5 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: 
"route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "", + }, + { + name: "mixed statuses - connected wins over idle and connecting", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusConnecting, + relayed: false, + latency: 5 * time.Millisecond, + }, + "route2": { + status: peer.StatusIdle, + relayed: false, + latency: 10 * time.Millisecond, + }, + "route3": { + status: peer.StatusConnected, + relayed: true, + latency: 200 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + "route3": { + ID: "route3", + Metric: route.MaxMetric, + Peer: "peer3", + }, + }, + currentRoute: "", + expectedRouteID: "route3", + }, + { + name: "idle peer with better metric should win over idle peer with worse metric", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 50 * time.Millisecond, + }, + "route2": { + status: peer.StatusIdle, + relayed: false, + latency: 50 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: 5000, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route1", + }, + { + name: "current idle route should be maintained for similar scores", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 20 * time.Millisecond, + }, + "route2": { + status: peer.StatusIdle, + relayed: false, + latency: 15 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "idle peer with zero latency should still be considered", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 0 * time.Millisecond, + }, + "route2": { + status: peer.StatusConnecting, + relayed: false, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route1", + }, + { + name: "direct idle peer preferred over relayed idle peer", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: true, + latency: 10 * time.Millisecond, + }, + "route2": { + status: peer.StatusIdle, + relayed: false, + latency: 50 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "connected peer with worse metric still beats idle peer with better metric", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 10 * time.Millisecond, + }, + "route2": { + status: peer.StatusConnected, + 
relayed: false, + latency: 50 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: 1000, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "connected peer wins even when idle peer has all advantages", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + status: peer.StatusIdle, + relayed: false, + latency: 1 * time.Millisecond, + }, + "route2": { + status: peer.StatusConnected, + relayed: true, + latency: 30 * time.Minute, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: 1, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + } // fill the test data with random routes @@ -368,18 +785,18 @@ func TestGetBestrouteFromStatuses(t *testing.T) { for i := 0; i < 50; i++ { id := route.ID(fmt.Sprintf("dummy_p1_%d", i)) dummyStatus := routerPeerStatus{ - connected: false, - relayed: true, - latency: 0, + status: peer.StatusConnecting, + relayed: true, + latency: 0, } tc.statuses[id] = dummyStatus } for i := 0; i < 50; i++ { id := route.ID(fmt.Sprintf("dummy_p2_%d", i)) dummyStatus := routerPeerStatus{ - connected: false, - relayed: true, - latency: 0, + status: peer.StatusConnecting, + relayed: true, + latency: 0, } tc.statuses[id] = dummyStatus } @@ -401,7 +818,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { currentChosen: currentRoute, } - chosenRoute := client.getBestRouteFromStatuses(tc.statuses) + chosenRoute, _ := client.getBestRouteFromStatuses(tc.statuses) if chosenRoute != tc.expectedRouteID { t.Errorf("expected routeID %s, got %s",
tc.expectedRouteID, chosenRoute) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 8dbbb5f77..93a3788eb 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -41,7 +41,8 @@ import ( // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error + UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error + ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap @@ -319,7 +320,12 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { return nberrors.FormatErrorOrNil(merr) } -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { +func (m *DefaultManager) UpdateRoutes( + updateSerial uint64, + serverRoutes map[route.ID]*route.Route, + clientRoutes route.HAMap, + useNewDNSRoute bool, +) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") @@ -331,11 +337,9 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro defer m.mux.Unlock() m.useNewDNSRoute = useNewDNSRoute - newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) - var merr *multierror.Error if !m.disableClientRoutes { - filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + filteredClientRoutes := m.routeSelector.FilterSelected(clientRoutes) if err := m.updateSystemRoutes(filteredClientRoutes); err != nil { merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err)) @@ -344,13 +348,13 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.updateClientNetworks(updateSerial, filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes) } - m.clientRoutes = newClientRoutesIDMap + m.clientRoutes = clientRoutes if m.serverRouter == nil { return nberrors.FormatErrorOrNil(merr) } - if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { + if err := m.serverRouter.UpdateRoutes(serverRoutes, useNewDNSRoute); err != nil { merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err)) } @@ -481,7 +485,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout } } -func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) { +func (m *DefaultManager) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) { newClientRoutesIDMap := make(route.HAMap) newServerRoutesMap := make(map[route.ID]*route.Route) ownNetworkIDs := make(map[route.HAUniqueID]bool) @@ -508,7 +512,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] } func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route { - _, crMap := m.classifyRoutes(initialRoutes) + _, crMap := m.ClassifyRoutes(initialRoutes) rs := make([]*route.Route, 0, len(crMap)) for _, routes := range crMap { rs = append(rs, routes...) 
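Note: with the manager.go change above, UpdateRoutes no longer classifies routes itself; the caller is expected to run ClassifyRoutes first and pass both result sets in, sharing the client-route HA map with the lazy-connection manager before routes are applied. A minimal sketch of the new calling convention, mirroring the engine.go hunk earlier in this diff (serial, routes, dnsRouteFeatureFlag, connMgr, and routeManager stand in for values the caller already holds):

	serverRoutes, clientRoutes := routeManager.ClassifyRoutes(routes)

	// The lazy connection manager must learn the HA groups before the routes
	// take effect, so that activating one HA peer can fan out to its group.
	if connMgr != nil {
		connMgr.UpdateRouteHAMap(clientRoutes)
	}

	if err := routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
		log.Errorf("failed to update routes: %v", err)
	}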
diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index a46ae080e..486ee080a 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -439,12 +439,14 @@ func TestManagerUpdateRoutes(t *testing.T) { routeManager.serverRouter = nil } + serverRoutes, clientRoutes := routeManager.ClassifyRoutes(testCase.inputRoutes) + if len(testCase.inputInitRoutes) > 0 { - err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) + err = routeManager.UpdateRoutes(testCase.inputSerial, serverRoutes, clientRoutes, false) require.NoError(t, err, "should update routes with init routes") } - err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) + err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), serverRoutes, clientRoutes, false) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 64fdffceb..63bad689e 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -14,7 +14,8 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) + UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap @@ -32,13 +33,21 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error { if m.UpdateRoutesFunc != nil { - return m.UpdateRoutesFunc(updateSerial, newRoutes) + return m.UpdateRoutesFunc(updateSerial, newRoutes, clientRoutes, useNewDNSRoute) } return nil } +// ClassifyRoutes mock implementation of ClassifyRoutes from Manager interface +func (m *MockManager) ClassifyRoutes(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) { + if m.ClassifyRoutesFunc != nil { + return m.ClassifyRoutesFunc(routes) + } + return nil, nil +} + func (m *MockManager) TriggerSelection(networks route.HAMap) { if m.TriggerSelectionFunc != nil { m.TriggerSelectionFunc(networks)
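Note: the HA bookkeeping added to the lazy-connection manager in this PR can be exercised in isolation. A self-contained sketch of the inversion UpdateRouteHAMap performs, under simplifying assumptions (plain strings instead of route.HAUniqueID keys and route.Route values, and no routesMu locking):

	package main

	import "fmt"

	// buildHAIndex mirrors UpdateRouteHAMap: HA groups with a single routing
	// peer are skipped; all others are indexed both ways so that activating
	// one peer can look up and activate the rest of its groups.
	func buildHAIndex(haMap map[string][]string) (groupToPeers, peerToGroups map[string][]string) {
		groupToPeers = make(map[string][]string)
		peerToGroups = make(map[string][]string)

		for group, routePeers := range haMap {
			seen := make(map[string]bool)
			var peers []string
			for _, p := range routePeers {
				if !seen[p] {
					seen[p] = true
					peers = append(peers, p)
				}
			}
			if len(peers) <= 1 {
				continue // single-peer groups need no fan-out activation
			}
			groupToPeers[group] = peers
			for _, p := range peers {
				peerToGroups[p] = append(peerToGroups[p], group)
			}
		}
		return groupToPeers, peerToGroups
	}

	func main() {
		groups, peers := buildHAIndex(map[string][]string{
			"n1|192.168.0.0/24": {"p1", "p2", "p2"}, // duplicate route entries for one peer collapse
			"n2|192.168.1.0/24": {"p1"},             // single-peer group: not tracked
		})
		fmt.Println(groups) // map[n1|192.168.0.0/24:[p1 p2]]
		fmt.Println(peers)  // map[p1:[n1|192.168.0.0/24] p2:[n1|192.168.0.0/24]]
	}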