diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 3f49c23fd..7b845235c 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() { } } - log.Debugf("extra match domains: %v", s.extraDomains) + log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains)) if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { log.Errorf("failed to apply DNS host manager update: %v", err) diff --git a/client/internal/engine.go b/client/internal/engine.go index 5efc0b92b..d015c1d6c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -994,6 +994,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } + protoDNSConfig := networkMap.GetDNSConfig() + if protoDNSConfig == nil { + protoDNSConfig = &mgmProto.DNSConfig{} + } + + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { + log.Errorf("failed to update dns server, err: %v", err) + } + dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) // apply routes first, route related actions might depend on routing being enabled @@ -1061,15 +1070,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers()) e.connMgr.SetExcludeList(excludedLazyPeers) - protoDNSConfig := networkMap.GetDNSConfig() - if protoDNSConfig == nil { - protoDNSConfig = &mgmProto.DNSConfig{} - } - - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { - log.Errorf("failed to update dns server, err: %v", err) - } - e.networkSerial = serial // Test received (upstream) servers for availability right away instead of upon usage. diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client/client.go similarity index 52% rename from client/internal/routemanager/client.go rename to client/internal/routemanager/client/client.go index bff954c27..5582591a9 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client/client.go @@ -1,4 +1,4 @@ -package routemanager +package client import ( "context" @@ -7,10 +7,8 @@ import ( "runtime" "time" - "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" - nberrors "github.com/netbirdio/netbird/client/errors" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" @@ -36,6 +34,7 @@ const ( reasonRouteUpdate reasonPeerUpdate reasonShutdown + reasonHA ) type routerPeerStatus struct { @@ -44,9 +43,9 @@ type routerPeerStatus struct { latency time.Duration } -type routesUpdate struct { - updateSerial uint64 - routes []*route.Route +type RoutesUpdate struct { + UpdateSerial uint64 + Routes []*route.Route } // RouteHandler defines the interface for handling routes @@ -58,64 +57,54 @@ type RouteHandler interface { RemoveAllowedIPs() error } -type clientNetwork struct { +type WatcherConfig struct { + Context context.Context + DNSRouteInterval time.Duration + WGInterface iface.WGIface + StatusRecorder *peer.Status + Route *route.Route + Handler RouteHandler +} + +// Watcher watches route and peer changes and updates allowed IPs accordingly. +// Once stopped, it cannot be reused. +type Watcher struct { ctx context.Context cancel context.CancelFunc statusRecorder *peer.Status wgInterface iface.WGIface routes map[route.ID]*route.Route - routeUpdate chan routesUpdate + routeUpdate chan RoutesUpdate peerStateUpdate chan struct{} - routePeersNotifiers map[string]chan struct{} + routePeersNotifiers map[string]chan struct{} // map of peer key to channel for peer state changes currentChosen *route.Route handler RouteHandler updateSerial uint64 } -func newClientNetworkWatcher( - ctx context.Context, - dnsRouteInterval time.Duration, - wgInterface iface.WGIface, - statusRecorder *peer.Status, - rt *route.Route, - routeRefCounter *refcounter.RouteRefCounter, - allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, - dnsServer nbdns.Server, - peerStore *peerstore.Store, - useNewDNSRoute bool, -) *clientNetwork { - ctx, cancel := context.WithCancel(ctx) +func NewWatcher(config WatcherConfig) *Watcher { + ctx, cancel := context.WithCancel(config.Context) - client := &clientNetwork{ + client := &Watcher{ ctx: ctx, cancel: cancel, - statusRecorder: statusRecorder, - wgInterface: wgInterface, + statusRecorder: config.StatusRecorder, + wgInterface: config.WGInterface, routes: make(map[route.ID]*route.Route), routePeersNotifiers: make(map[string]chan struct{}), - routeUpdate: make(chan routesUpdate), + routeUpdate: make(chan RoutesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute( - rt, - routeRefCounter, - allowedIPsRefCounter, - dnsRouteInterval, - statusRecorder, - wgInterface, - dnsServer, - peerStore, - useNewDNSRoute, - ), + handler: config.Handler, } return client } -func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { +func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus { routePeerStatuses := make(map[route.ID]routerPeerStatus) - for _, r := range c.routes { - peerStatus, err := c.statusRecorder.GetPeer(r.Peer) + for _, r := range w.routes { + peerStatus, err := w.statusRecorder.GetPeer(r.Peer) if err != nil { - log.Debugf("couldn't fetch peer state: %v", err) + log.Debugf("couldn't fetch peer state %v: %v", r.Peer, err) continue } routePeerStatuses[r.ID] = routerPeerStatus{ @@ -128,7 +117,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { } // getBestRouteFromStatuses determines the most optimal route from the available routes -// within a clientNetwork, taking into account peer connection status, route metrics, and +// within a Watcher, taking into account peer connection status, route metrics, and // preference for non-relayed and direct connections. // // It follows these prioritization rules: @@ -140,17 +129,17 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { // * Stability: In case of equal scores, the currently active route (if any) is maintained. // // It returns the ID of the selected optimal route. -func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { - chosen := route.ID("") +func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { + var chosen route.ID chosenScore := float64(0) currScore := float64(0) - currID := route.ID("") - if c.currentChosen != nil { - currID = c.currentChosen.ID + var currID route.ID + if w.currentChosen != nil { + currID = w.currentChosen.ID } - for _, r := range c.routes { + for _, r := range w.routes { tempScore := float64(0) peerStatus, found := routePeerStatuses[r.ID] if !found || !peerStatus.connected { @@ -167,7 +156,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] if peerStatus.latency != 0 { latency = peerStatus.latency } else { - log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler) + log.Tracef("peer %s has 0 latency, range %s", r.Peer, w.handler) } // avoid negative tempScore on the higher latency calculation @@ -197,35 +186,45 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] } } - log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore) + chosenID := chosen + if chosen == "" { + chosenID = "" + } + currentID := currID + if currID == "" { + currentID = "" + } + + log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosenID, chosenScore, currentID, currScore) switch { case chosen == "": var peers []string - for _, r := range c.routes { + for _, r := range w.routes { peers = append(peers, r.Peer) } - log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers) + log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", w.handler, peers) case chosen != currID: // we compare the current score + 10ms to the chosen score to avoid flapping between routes if currScore != 0 && currScore+0.01 > chosenScore { - log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore) + log.Debugf("keeping current routing peer %s for [%v]: the score difference with latency is less than 0.01(10ms): current: %f, new: %f", + w.currentChosen.Peer, w.handler, currScore, chosenScore) return currID } var p string - if rt := c.routes[chosen]; rt != nil { + if rt := w.routes[chosen]; rt != nil { p = rt.Peer } - log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler) + log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, w.handler) } return chosen } -func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) { - subscription := c.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey) - defer c.statusRecorder.UnsubscribePeerStateChanges(subscription) +func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) { + subscription := w.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey) + defer w.statusRecorder.UnsubscribePeerStateChanges(subscription) for { select { @@ -240,105 +239,92 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri } } -func (c *clientNetwork) startPeersStatusChangeWatcher() { - for _, r := range c.routes { - _, found := c.routePeersNotifiers[r.Peer] - if found { +func (w *Watcher) startNewPeerStatusWatchers() { + for _, r := range w.routes { + if _, found := w.routePeersNotifiers[r.Peer]; found { continue } closerChan := make(chan struct{}) - c.routePeersNotifiers[r.Peer] = closerChan - go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan) + w.routePeersNotifiers[r.Peer] = closerChan + go w.watchPeerStatusChanges(w.ctx, r.Peer, w.peerStateUpdate, closerChan) } } -func (c *clientNetwork) removeRouteFromWireGuardPeer() error { - if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil { +// addAllowedIPs adds the allowed IPs for the current chosen route to the handler. +func (w *Watcher) addAllowedIPs(route *route.Route) error { + if err := w.handler.AddAllowedIPs(route.Peer); err != nil { + return fmt.Errorf("add allowed IPs for peer %s: %w", route.Peer, err) + } + + if err := w.statusRecorder.AddPeerStateRoute(route.Peer, w.handler.String(), route.GetResourceID()); err != nil { log.Warnf("Failed to update peer state: %v", err) } - if err := c.handler.RemoveAllowedIPs(); err != nil { - return fmt.Errorf("remove allowed IPs: %w", err) - } + w.connectEvent(route) return nil } -func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error { - if c.currentChosen == nil { - return nil +func (w *Watcher) removeAllowedIPs(route *route.Route, rsn reason) error { + if err := w.statusRecorder.RemovePeerStateRoute(route.Peer, w.handler.String()); err != nil { + log.Warnf("Failed to update peer state: %v", err) } - var merr *multierror.Error - - if err := c.removeRouteFromWireGuardPeer(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) - } - if err := c.handler.RemoveRoute(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err)) + if err := w.handler.RemoveAllowedIPs(); err != nil { + return fmt.Errorf("remove allowed IPs: %w", err) } - c.disconnectEvent(rsn) + w.disconnectEvent(route, rsn) - return nberrors.FormatErrorOrNil(merr) + return nil } -func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error { - routerPeerStatuses := c.getRouterPeerStatuses() +func (w *Watcher) recalculateRoutes(rsn reason) error { + routerPeerStatuses := w.getRouterPeerStatuses() - newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses) + newChosenID := w.getBestRouteFromStatuses(routerPeerStatuses) - // If no route is chosen, remove the route from the peer and system + // If no route is chosen, remove the route from the peer if newChosenID == "" { - if err := c.removeRouteFromPeerAndSystem(rsn); err != nil { - return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err) + if w.currentChosen == nil { + return nil } - c.currentChosen = nil + if err := w.removeAllowedIPs(w.currentChosen, rsn); err != nil { + return fmt.Errorf("remove obsolete: %w", err) + } + + w.currentChosen = nil return nil } // If the chosen route is the same as the current route, do nothing - if c.currentChosen != nil && c.currentChosen.ID == newChosenID && - c.currentChosen.Equal(c.routes[newChosenID]) { + if w.currentChosen != nil && w.currentChosen.ID == newChosenID && + w.currentChosen.Equal(w.routes[newChosenID]) { return nil } - var isNew bool - if c.currentChosen == nil { - // If they were not previously assigned to another peer, add routes to the system first - if err := c.handler.AddRoute(c.ctx); err != nil { - return fmt.Errorf("add route: %w", err) - } - isNew = true - } else { - // Otherwise, remove the allowed IPs from the previous peer first - if err := c.removeRouteFromWireGuardPeer(); err != nil { - return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) + // If the chosen route was assigned to a different peer, remove the allowed IPs first + if isNew := w.currentChosen == nil; !isNew { + if err := w.removeAllowedIPs(w.currentChosen, reasonHA); err != nil { + return fmt.Errorf("remove old: %w", err) } } - c.currentChosen = c.routes[newChosenID] - - if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil { - return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) + newChosenRoute := w.routes[newChosenID] + if err := w.addAllowedIPs(newChosenRoute); err != nil { + return fmt.Errorf("add new: %w", err) } - if isNew { - c.connectEvent() - } + w.currentChosen = newChosenRoute - err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID()) - if err != nil { - return fmt.Errorf("add peer state route: %w", err) - } return nil } -func (c *clientNetwork) connectEvent() { +func (w *Watcher) connectEvent(route *route.Route) { var defaultRoute bool - for _, r := range c.routes { + for _, r := range w.routes { if r.Network.Bits() == 0 { defaultRoute = true break @@ -350,13 +336,13 @@ func (c *clientNetwork) connectEvent() { } meta := map[string]string{ - "network": c.handler.String(), + "network": w.handler.String(), } - if c.currentChosen != nil { - meta["id"] = string(c.currentChosen.NetID) - meta["peer"] = c.currentChosen.Peer + if route != nil { + meta["id"] = string(route.NetID) + meta["peer"] = route.Peer } - c.statusRecorder.PublishEvent( + w.statusRecorder.PublishEvent( proto.SystemEvent_INFO, proto.SystemEvent_NETWORK, "Default route added", @@ -365,9 +351,9 @@ func (c *clientNetwork) connectEvent() { ) } -func (c *clientNetwork) disconnectEvent(rsn reason) { +func (w *Watcher) disconnectEvent(route *route.Route, rsn reason) { var defaultRoute bool - for _, r := range c.routes { + for _, r := range w.routes { if r.Network.Bits() == 0 { defaultRoute = true break @@ -383,11 +369,11 @@ func (c *clientNetwork) disconnectEvent(rsn reason) { var userMessage string meta := make(map[string]string) - if c.currentChosen != nil { - meta["id"] = string(c.currentChosen.NetID) - meta["peer"] = c.currentChosen.Peer + if route != nil { + meta["id"] = string(route.NetID) + meta["peer"] = route.Peer } - meta["network"] = c.handler.String() + meta["network"] = w.handler.String() switch rsn { case reasonShutdown: severity = proto.SystemEvent_INFO @@ -400,13 +386,17 @@ func (c *clientNetwork) disconnectEvent(rsn reason) { severity = proto.SystemEvent_WARNING message = "Default route disconnected due to peer unreachability" userMessage = "Exit node connection lost. Your internet access might be affected." + case reasonHA: + severity = proto.SystemEvent_INFO + message = "Default route disconnected due to high availability change" + userMessage = "Exit node disconnected due to high availability change." default: severity = proto.SystemEvent_ERROR message = "Default route disconnected for unknown reasons" userMessage = "Exit node disconnected for unknown reasons." } - c.statusRecorder.PublishEvent( + w.statusRecorder.PublishEvent( severity, proto.SystemEvent_NETWORK, message, @@ -415,86 +405,101 @@ func (c *clientNetwork) disconnectEvent(rsn reason) { ) } -func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { +func (w *Watcher) SendUpdate(update RoutesUpdate) { go func() { - c.routeUpdate <- update + select { + case w.routeUpdate <- update: + case <-w.ctx.Done(): + } }() } -func (c *clientNetwork) handleUpdate(update routesUpdate) bool { +func (w *Watcher) classifyUpdate(update RoutesUpdate) bool { isUpdateMapDifferent := false updateMap := make(map[route.ID]*route.Route) - for _, r := range update.routes { + for _, r := range update.Routes { updateMap[r.ID] = r } - if len(c.routes) != len(updateMap) { + if len(w.routes) != len(updateMap) { isUpdateMapDifferent = true } - for id, r := range c.routes { + for id, r := range w.routes { _, found := updateMap[id] if !found { - close(c.routePeersNotifiers[r.Peer]) - delete(c.routePeersNotifiers, r.Peer) + close(w.routePeersNotifiers[r.Peer]) + delete(w.routePeersNotifiers, r.Peer) isUpdateMapDifferent = true continue } - if !reflect.DeepEqual(c.routes[id], updateMap[id]) { + if !reflect.DeepEqual(w.routes[id], updateMap[id]) { isUpdateMapDifferent = true } } - c.routes = updateMap + w.routes = updateMap return isUpdateMapDifferent } -// peersStateAndUpdateWatcher is the main point of reacting on client network routing events. +// Start is the main point of reacting on client network routing events. // All the processing related to the client network should be done here. Thread-safe. -func (c *clientNetwork) peersStateAndUpdateWatcher() { +func (w *Watcher) Start() { for { select { - case <-c.ctx.Done(): - log.Debugf("Stopping watcher for network [%v]", c.handler) - if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil { - log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err) - } + case <-w.ctx.Done(): return - case <-c.peerStateUpdate: - err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate) - if err != nil { - log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) + case <-w.peerStateUpdate: + if err := w.recalculateRoutes(reasonPeerUpdate); err != nil { + log.Errorf("Failed to recalculate routes for network [%v]: %v", w.handler, err) } - case update := <-c.routeUpdate: - if update.updateSerial < c.updateSerial { - log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial) + case update := <-w.routeUpdate: + if update.UpdateSerial < w.updateSerial { + log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", w.updateSerial, update.UpdateSerial) continue } - log.Debugf("Received a new client network route update for [%v]", c.handler) - - // hash update somehow - isTrueRouteUpdate := c.handleUpdate(update) - - c.updateSerial = update.updateSerial - - if isTrueRouteUpdate { - log.Debug("Client network update contains different routes, recalculating routes") - err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate) - if err != nil { - log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) - } - } else { - log.Debug("Route update is not different, skipping route recalculation") - } - - c.startPeersStatusChangeWatcher() + w.handleRouteUpdate(update) } } } -func handlerFromRoute( +func (w *Watcher) handleRouteUpdate(update RoutesUpdate) { + log.Debugf("Received a new client network route update for [%v]", w.handler) + + // hash update somehow + isTrueRouteUpdate := w.classifyUpdate(update) + + w.updateSerial = update.UpdateSerial + + if isTrueRouteUpdate { + log.Debugf("client network update %v for [%v] contains different routes, recalculating routes", update.UpdateSerial, w.handler) + if err := w.recalculateRoutes(reasonRouteUpdate); err != nil { + log.Errorf("failed to recalculate routes for network [%v]: %v", w.handler, err) + } + } else { + log.Debugf("route update %v for [%v] is not different, skipping route recalculation", update.UpdateSerial, w.handler) + } + + w.startNewPeerStatusWatchers() +} + +// Stop stops the watcher and cleans up resources. +func (w *Watcher) Stop() { + log.Debugf("Stopping watcher for network [%v]", w.handler) + + w.cancel() + + if w.currentChosen == nil { + return + } + if err := w.removeAllowedIPs(w.currentChosen, reasonShutdown); err != nil { + log.Errorf("Failed to remove routes for [%v]: %v", w.handler, err) + } +} + +func HandlerFromRoute( rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client/client_test.go similarity index 99% rename from client/internal/routemanager/client_test.go rename to client/internal/routemanager/client/client_test.go index 56fcf1613..48a9495bf 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client/client_test.go @@ -1,4 +1,4 @@ -package routemanager +package client import ( "fmt" @@ -395,7 +395,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { } // create new clientNetwork - client := &clientNetwork{ + client := &Watcher{ handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil), routes: tc.existingRoutes, currentChosen: currentRoute, diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 078206ab9..afb74c23e 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -11,9 +11,11 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/netstack" @@ -21,9 +23,11 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/client" "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/routemanager/server" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" @@ -68,9 +72,9 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex - clientNetworks map[route.HAUniqueID]*clientNetwork + clientNetworks map[route.HAUniqueID]*client.Watcher routeSelector *routeselector.RouteSelector - serverRouter *serverRouter + serverRouter *server.Router sysOps *systemops.SysOps statusRecorder *peer.Status relayMgr *relayClient.Manager @@ -88,6 +92,7 @@ type DefaultManager struct { useNewDNSRoute bool disableClientRoutes bool disableServerRoutes bool + activeRoutes map[route.HAUniqueID]client.RouteHandler } func NewManager(config ManagerConfig) *DefaultManager { @@ -99,7 +104,7 @@ func NewManager(config ManagerConfig) *DefaultManager { ctx: mCTX, stop: cancel, dnsRouteInterval: config.DNSRouteInterval, - clientNetworks: make(map[route.HAUniqueID]*clientNetwork), + clientNetworks: make(map[route.HAUniqueID]*client.Watcher), relayMgr: config.RelayManager, sysOps: sysOps, statusRecorder: config.StatusRecorder, @@ -111,6 +116,7 @@ func NewManager(config ManagerConfig) *DefaultManager { peerStore: config.PeerStore, disableClientRoutes: config.DisableClientRoutes, disableServerRoutes: config.DisableServerRoutes, + activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), } useNoop := netstack.IsEnabled() || config.DisableClientRoutes @@ -226,7 +232,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { } var err error - m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) + m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) if err != nil { return err } @@ -237,7 +243,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() if m.serverRouter != nil { - m.serverRouter.cleanUp() + m.serverRouter.CleanUp() } if m.routeRefCounter != nil { @@ -265,6 +271,54 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps +func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { + toAdd := make(map[route.HAUniqueID]*route.Route) + toRemove := make(map[route.HAUniqueID]client.RouteHandler) + + for id, routes := range newRoutes { + if len(routes) > 0 { + toAdd[id] = routes[0] + } + } + + for id, activeHandler := range m.activeRoutes { + if _, exists := toAdd[id]; exists { + delete(toAdd, id) + } else { + toRemove[id] = activeHandler + } + } + + var merr *multierror.Error + for id, handler := range toRemove { + if err := handler.RemoveRoute(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err)) + } + delete(m.activeRoutes, id) + } + + for id, route := range toAdd { + handler := client.HandlerFromRoute( + route, + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsRouteInterval, + m.statusRecorder, + m.wgInterface, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) + if err := handler.AddRoute(m.ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err)) + continue + } + m.activeRoutes[id] = handler + } + + return nberrors.FormatErrorOrNil(merr) +} + func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { select { case <-m.ctx.Done(): @@ -281,6 +335,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if !m.disableClientRoutes { filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + + if err := m.updateSystemRoutes(filteredClientRoutes); err != nil { + log.Errorf("Failed to update system routes: %v", err) + } + m.updateClientNetworks(updateSerial, filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes) } @@ -290,7 +349,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro return nil } - if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { + if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { return fmt.Errorf("update routes: %w", err) } @@ -341,6 +400,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { m.notifier.OnNewRoutes(networks) + if err := m.updateSystemRoutes(networks); err != nil { + log.Errorf("failed to update system routes during selection: %v", err) + } + m.stopObsoleteClients(networks) for id, routes := range networks { @@ -349,21 +412,24 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { continue } - clientNetworkWatcher := newClientNetworkWatcher( - m.ctx, - m.dnsRouteInterval, - m.wgInterface, - m.statusRecorder, - routes[0], - m.routeRefCounter, - m.allowedIPsRefCounter, - m.dnsServer, - m.peerStore, - m.useNewDNSRoute, - ) + handler := m.activeRoutes[id] + if handler == nil { + log.Warnf("no active handler found for route %s", id) + continue + } + + config := client.WatcherConfig{ + Context: m.ctx, + DNSRouteInterval: m.dnsRouteInterval, + WGInterface: m.wgInterface, + StatusRecorder: m.statusRecorder, + Route: routes[0], + Handler: handler, + } + clientNetworkWatcher := client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.peersStateAndUpdateWatcher() - clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) + go clientNetworkWatcher.Start() + clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) } if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil { @@ -375,8 +441,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) { for id, client := range m.clientNetworks { if _, ok := networks[id]; !ok { - log.Debugf("Stopping client network watcher, %s", id) - client.cancel() + client.Stop() delete(m.clientNetworks, id) } } @@ -389,26 +454,29 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher( - m.ctx, - m.dnsRouteInterval, - m.wgInterface, - m.statusRecorder, - routes[0], - m.routeRefCounter, - m.allowedIPsRefCounter, - m.dnsServer, - m.peerStore, - m.useNewDNSRoute, - ) + handler := m.activeRoutes[id] + if handler == nil { + log.Errorf("No active handler found for route %s", id) + continue + } + + config := client.WatcherConfig{ + Context: m.ctx, + DNSRouteInterval: m.dnsRouteInterval, + WGInterface: m.wgInterface, + StatusRecorder: m.statusRecorder, + Route: routes[0], + Handler: handler, + } + clientNetworkWatcher = client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.peersStateAndUpdateWatcher() + go clientNetworkWatcher.Start() } - update := routesUpdate{ - updateSerial: updateSerial, - routes: routes, + update := client.RoutesUpdate{ + UpdateSerial: updateSerial, + Routes: routes, } - clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) + clientNetworkWatcher.SendUpdate(update) } } diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 318ef5ae5..680bd813f 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/netip" - "runtime" "testing" "github.com/pion/transport/v3/stdnet" @@ -454,8 +453,8 @@ func TestManagerUpdateRoutes(t *testing.T) { } require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") - if runtime.GOOS == "linux" && routeManager.serverRouter != nil { - require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") + if routeManager.serverRouter != nil { + require.Equal(t, testCase.serverRoutesExpected, routeManager.serverRouter.RoutesCount(), "server networks size should match") } }) } diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server/server.go similarity index 63% rename from client/internal/routemanager/server.go rename to client/internal/routemanager/server/server.go index 5bacb856c..e674c80cd 100644 --- a/client/internal/routemanager/server.go +++ b/client/internal/routemanager/server/server.go @@ -1,4 +1,4 @@ -package routemanager +package server import ( "context" @@ -14,7 +14,7 @@ import ( "github.com/netbirdio/netbird/route" ) -type serverRouter struct { +type Router struct { mux sync.Mutex ctx context.Context routes map[route.ID]*route.Route @@ -23,8 +23,8 @@ type serverRouter struct { statusRecorder *peer.Status } -func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { - return &serverRouter{ +func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) { + return &Router{ ctx: ctx, routes: make(map[route.ID]*route.Route), firewall: firewall, @@ -33,104 +33,110 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi }, nil } -func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error { - m.mux.Lock() - defer m.mux.Unlock() +func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error { + r.mux.Lock() + defer r.mux.Unlock() serverRoutesToRemove := make([]route.ID, 0) - for routeID := range m.routes { + for routeID := range r.routes { update, found := routesMap[routeID] - if !found || !update.Equal(m.routes[routeID]) { + if !found || !update.Equal(r.routes[routeID]) { serverRoutesToRemove = append(serverRoutesToRemove, routeID) } } for _, routeID := range serverRoutesToRemove { - oldRoute := m.routes[routeID] - err := m.removeFromServerNetwork(oldRoute) + oldRoute := r.routes[routeID] + err := r.removeFromServerNetwork(oldRoute) if err != nil { log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } - delete(m.routes, routeID) + delete(r.routes, routeID) } // If routing is to be disabled, do it after routes have been removed // If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled if len(routesMap) > 0 { - if err := m.firewall.EnableRouting(); err != nil { + if err := r.firewall.EnableRouting(); err != nil { return fmt.Errorf("enable routing: %w", err) } } else { - if err := m.firewall.DisableRouting(); err != nil { + if err := r.firewall.DisableRouting(); err != nil { return fmt.Errorf("disable routing: %w", err) } } for id, newRoute := range routesMap { - _, found := m.routes[id] + _, found := r.routes[id] if found { continue } - err := m.addToServerNetwork(newRoute, useNewDNSRoute) + err := r.addToServerNetwork(newRoute, useNewDNSRoute) if err != nil { log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue } - m.routes[id] = newRoute + r.routes[id] = newRoute } return nil } -func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { - if m.ctx.Err() != nil { +func (r *Router) removeFromServerNetwork(route *route.Route) error { + if r.ctx.Err() != nil { log.Infof("Not removing from server network because context is done") - return m.ctx.Err() + return r.ctx.Err() } routerPair := routeToRouterPair(route, false) - if err := m.firewall.RemoveNatRule(routerPair); err != nil { + if err := r.firewall.RemoveNatRule(routerPair); err != nil { return fmt.Errorf("remove routing rules: %w", err) } - delete(m.routes, route.ID) - m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString()) + delete(r.routes, route.ID) + r.statusRecorder.RemoveLocalPeerStateRoute(route.NetString()) return nil } -func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error { - if m.ctx.Err() != nil { +func (r *Router) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error { + if r.ctx.Err() != nil { log.Infof("Not adding to server network because context is done") - return m.ctx.Err() + return r.ctx.Err() } routerPair := routeToRouterPair(route, useNewDNSRoute) - if err := m.firewall.AddNatRule(routerPair); err != nil { + if err := r.firewall.AddNatRule(routerPair); err != nil { return fmt.Errorf("insert routing rules: %w", err) } - m.routes[route.ID] = route - m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID()) + r.routes[route.ID] = route + r.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID()) return nil } -func (m *serverRouter) cleanUp() { - m.mux.Lock() - defer m.mux.Unlock() +func (r *Router) CleanUp() { + r.mux.Lock() + defer r.mux.Unlock() - for _, r := range m.routes { - routerPair := routeToRouterPair(r, false) - if err := m.firewall.RemoveNatRule(routerPair); err != nil { + for _, route := range r.routes { + routerPair := routeToRouterPair(route, false) + if err := r.firewall.RemoveNatRule(routerPair); err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } } - m.statusRecorder.CleanLocalPeerStateRoutes() + r.statusRecorder.CleanLocalPeerStateRoutes() +} + +func (r *Router) RoutesCount() int { + r.mux.Lock() + defer r.mux.Unlock() + return len(r.routes) } func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair { diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index 681c192fb..c8b9338e0 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -29,13 +29,17 @@ func (r *Route) String() string { } func (r *Route) AddRoute(context.Context) error { - _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) - return err + if _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}); err != nil { + return err + } + return nil } func (r *Route) RemoveRoute() error { - _, err := r.routeRefCounter.Decrement(r.route.Network) - return err + if _, err := r.routeRefCounter.Decrement(r.route.Network); err != nil { + return err + } + return nil } func (r *Route) AddAllowedIPs(peerKey string) error { @@ -51,6 +55,8 @@ func (r *Route) AddAllowedIPs(peerKey string) error { } func (r *Route) RemoveAllowedIPs() error { - _, err := r.allowedIPsRefcounter.Decrement(r.route.Network) - return err + if _, err := r.allowedIPsRefcounter.Decrement(r.route.Network); err != nil { + return err + } + return nil }