[client] Apply routes right away instead of on peer connection (#3907)

This commit is contained in:
Viktor Liu 2025-06-03 10:53:39 +02:00 committed by GitHub
parent 1ce4ee0cef
commit 06980e7fa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 349 additions and 265 deletions

View File

@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
} }
} }
log.Debugf("extra match domains: %v", s.extraDomains) log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err) log.Errorf("failed to apply DNS host manager update: %v", err)

View File

@ -994,6 +994,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
@ -1061,15 +1070,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers()) excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(excludedLazyPeers) e.connMgr.SetExcludeList(excludedLazyPeers)
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage. // Test received (upstream) servers for availability right away instead of upon usage.

View File

@ -1,4 +1,4 @@
package routemanager package client
import ( import (
"context" "context"
@ -7,10 +7,8 @@ import (
"runtime" "runtime"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
@ -36,6 +34,7 @@ const (
reasonRouteUpdate reasonRouteUpdate
reasonPeerUpdate reasonPeerUpdate
reasonShutdown reasonShutdown
reasonHA
) )
type routerPeerStatus struct { type routerPeerStatus struct {
@ -44,9 +43,9 @@ type routerPeerStatus struct {
latency time.Duration latency time.Duration
} }
type routesUpdate struct { type RoutesUpdate struct {
updateSerial uint64 UpdateSerial uint64
routes []*route.Route Routes []*route.Route
} }
// RouteHandler defines the interface for handling routes // RouteHandler defines the interface for handling routes
@ -58,64 +57,54 @@ type RouteHandler interface {
RemoveAllowedIPs() error RemoveAllowedIPs() error
} }
type clientNetwork struct { type WatcherConfig struct {
Context context.Context
DNSRouteInterval time.Duration
WGInterface iface.WGIface
StatusRecorder *peer.Status
Route *route.Route
Handler RouteHandler
}
// Watcher watches route and peer changes and updates allowed IPs accordingly.
// Once stopped, it cannot be reused.
type Watcher struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.WGIface wgInterface iface.WGIface
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan RoutesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{} routePeersNotifiers map[string]chan struct{} // map of peer key to channel for peer state changes
currentChosen *route.Route currentChosen *route.Route
handler RouteHandler handler RouteHandler
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher( func NewWatcher(config WatcherConfig) *Watcher {
ctx context.Context, ctx, cancel := context.WithCancel(config.Context)
dnsRouteInterval time.Duration,
wgInterface iface.WGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &Watcher{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
statusRecorder: statusRecorder, statusRecorder: config.StatusRecorder,
wgInterface: wgInterface, wgInterface: config.WGInterface,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan RoutesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute( handler: config.Handler,
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouteInterval,
statusRecorder,
wgInterface,
dnsServer,
peerStore,
useNewDNSRoute,
),
} }
return client return client
} }
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[route.ID]routerPeerStatus) routePeerStatuses := make(map[route.ID]routerPeerStatus)
for _, r := range c.routes { for _, r := range w.routes {
peerStatus, err := c.statusRecorder.GetPeer(r.Peer) peerStatus, err := w.statusRecorder.GetPeer(r.Peer)
if err != nil { if err != nil {
log.Debugf("couldn't fetch peer state: %v", err) log.Debugf("couldn't fetch peer state %v: %v", r.Peer, err)
continue continue
} }
routePeerStatuses[r.ID] = routerPeerStatus{ routePeerStatuses[r.ID] = routerPeerStatus{
@ -128,7 +117,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
} }
// getBestRouteFromStatuses determines the most optimal route from the available routes // getBestRouteFromStatuses determines the most optimal route from the available routes
// within a clientNetwork, taking into account peer connection status, route metrics, and // within a Watcher, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections. // preference for non-relayed and direct connections.
// //
// It follows these prioritization rules: // It follows these prioritization rules:
@ -140,17 +129,17 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Stability: In case of equal scores, the currently active route (if any) is maintained. // * Stability: In case of equal scores, the currently active route (if any) is maintained.
// //
// It returns the ID of the selected optimal route. // It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
chosen := route.ID("") var chosen route.ID
chosenScore := float64(0) chosenScore := float64(0)
currScore := float64(0) currScore := float64(0)
currID := route.ID("") var currID route.ID
if c.currentChosen != nil { if w.currentChosen != nil {
currID = c.currentChosen.ID currID = w.currentChosen.ID
} }
for _, r := range c.routes { for _, r := range w.routes {
tempScore := float64(0) tempScore := float64(0)
peerStatus, found := routePeerStatuses[r.ID] peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected { if !found || !peerStatus.connected {
@ -167,7 +156,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
if peerStatus.latency != 0 { if peerStatus.latency != 0 {
latency = peerStatus.latency latency = peerStatus.latency
} else { } else {
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler) log.Tracef("peer %s has 0 latency, range %s", r.Peer, w.handler)
} }
// avoid negative tempScore on the higher latency calculation // avoid negative tempScore on the higher latency calculation
@ -197,35 +186,45 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
} }
} }
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore) chosenID := chosen
if chosen == "" {
chosenID = "<none>"
}
currentID := currID
if currID == "" {
currentID = "<none>"
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosenID, chosenScore, currentID, currScore)
switch { switch {
case chosen == "": case chosen == "":
var peers []string var peers []string
for _, r := range c.routes { for _, r := range w.routes {
peers = append(peers, r.Peer) peers = append(peers, r.Peer)
} }
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers) log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", w.handler, peers)
case chosen != currID: case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes // we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore { if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore) log.Debugf("keeping current routing peer %s for [%v]: the score difference with latency is less than 0.01(10ms): current: %f, new: %f",
w.currentChosen.Peer, w.handler, currScore, chosenScore)
return currID return currID
} }
var p string var p string
if rt := c.routes[chosen]; rt != nil { if rt := w.routes[chosen]; rt != nil {
p = rt.Peer p = rt.Peer
} }
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler) log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, w.handler)
} }
return chosen return chosen
} }
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) { func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
subscription := c.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey) subscription := w.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey)
defer c.statusRecorder.UnsubscribePeerStateChanges(subscription) defer w.statusRecorder.UnsubscribePeerStateChanges(subscription)
for { for {
select { select {
@ -240,105 +239,92 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
} }
} }
func (c *clientNetwork) startPeersStatusChangeWatcher() { func (w *Watcher) startNewPeerStatusWatchers() {
for _, r := range c.routes { for _, r := range w.routes {
_, found := c.routePeersNotifiers[r.Peer] if _, found := w.routePeersNotifiers[r.Peer]; found {
if found {
continue continue
} }
closerChan := make(chan struct{}) closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan w.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan) go w.watchPeerStatusChanges(w.ctx, r.Peer, w.peerStateUpdate, closerChan)
} }
} }
func (c *clientNetwork) removeRouteFromWireGuardPeer() error { // addAllowedIPs adds the allowed IPs for the current chosen route to the handler.
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil { func (w *Watcher) addAllowedIPs(route *route.Route) error {
if err := w.handler.AddAllowedIPs(route.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", route.Peer, err)
}
if err := w.statusRecorder.AddPeerStateRoute(route.Peer, w.handler.String(), route.GetResourceID()); err != nil {
log.Warnf("Failed to update peer state: %v", err) log.Warnf("Failed to update peer state: %v", err)
} }
if err := c.handler.RemoveAllowedIPs(); err != nil { w.connectEvent(route)
return nil
}
func (w *Watcher) removeAllowedIPs(route *route.Route, rsn reason) error {
if err := w.statusRecorder.RemovePeerStateRoute(route.Peer, w.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := w.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err) return fmt.Errorf("remove allowed IPs: %w", err)
} }
w.disconnectEvent(route, rsn)
return nil return nil
} }
func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error { func (w *Watcher) recalculateRoutes(rsn reason) error {
if c.currentChosen == nil { routerPeerStatuses := w.getRouterPeerStatuses()
return nil
}
var merr *multierror.Error newChosenID := w.getBestRouteFromStatuses(routerPeerStatuses)
if err := c.removeRouteFromWireGuardPeer(); err != nil { // If no route is chosen, remove the route from the peer
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
}
c.disconnectEvent(rsn)
return nberrors.FormatErrorOrNil(merr)
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error {
routerPeerStatuses := c.getRouterPeerStatuses()
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system
if newChosenID == "" { if newChosenID == "" {
if err := c.removeRouteFromPeerAndSystem(rsn); err != nil { if w.currentChosen == nil {
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err) return nil
} }
c.currentChosen = nil if err := w.removeAllowedIPs(w.currentChosen, rsn); err != nil {
return fmt.Errorf("remove obsolete: %w", err)
}
w.currentChosen = nil
return nil return nil
} }
// If the chosen route is the same as the current route, do nothing // If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID && if w.currentChosen != nil && w.currentChosen.ID == newChosenID &&
c.currentChosen.Equal(c.routes[newChosenID]) { w.currentChosen.Equal(w.routes[newChosenID]) {
return nil return nil
} }
var isNew bool // If the chosen route was assigned to a different peer, remove the allowed IPs first
if c.currentChosen == nil { if isNew := w.currentChosen == nil; !isNew {
// If they were not previously assigned to another peer, add routes to the system first if err := w.removeAllowedIPs(w.currentChosen, reasonHA); err != nil {
if err := c.handler.AddRoute(c.ctx); err != nil { return fmt.Errorf("remove old: %w", err)
return fmt.Errorf("add route: %w", err)
}
isNew = true
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
} }
} }
c.currentChosen = c.routes[newChosenID] newChosenRoute := w.routes[newChosenID]
if err := w.addAllowedIPs(newChosenRoute); err != nil {
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil { return fmt.Errorf("add new: %w", err)
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
} }
if isNew { w.currentChosen = newChosenRoute
c.connectEvent()
}
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
return nil return nil
} }
func (c *clientNetwork) connectEvent() { func (w *Watcher) connectEvent(route *route.Route) {
var defaultRoute bool var defaultRoute bool
for _, r := range c.routes { for _, r := range w.routes {
if r.Network.Bits() == 0 { if r.Network.Bits() == 0 {
defaultRoute = true defaultRoute = true
break break
@ -350,13 +336,13 @@ func (c *clientNetwork) connectEvent() {
} }
meta := map[string]string{ meta := map[string]string{
"network": c.handler.String(), "network": w.handler.String(),
} }
if c.currentChosen != nil { if route != nil {
meta["id"] = string(c.currentChosen.NetID) meta["id"] = string(route.NetID)
meta["peer"] = c.currentChosen.Peer meta["peer"] = route.Peer
} }
c.statusRecorder.PublishEvent( w.statusRecorder.PublishEvent(
proto.SystemEvent_INFO, proto.SystemEvent_INFO,
proto.SystemEvent_NETWORK, proto.SystemEvent_NETWORK,
"Default route added", "Default route added",
@ -365,9 +351,9 @@ func (c *clientNetwork) connectEvent() {
) )
} }
func (c *clientNetwork) disconnectEvent(rsn reason) { func (w *Watcher) disconnectEvent(route *route.Route, rsn reason) {
var defaultRoute bool var defaultRoute bool
for _, r := range c.routes { for _, r := range w.routes {
if r.Network.Bits() == 0 { if r.Network.Bits() == 0 {
defaultRoute = true defaultRoute = true
break break
@ -383,11 +369,11 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
var userMessage string var userMessage string
meta := make(map[string]string) meta := make(map[string]string)
if c.currentChosen != nil { if route != nil {
meta["id"] = string(c.currentChosen.NetID) meta["id"] = string(route.NetID)
meta["peer"] = c.currentChosen.Peer meta["peer"] = route.Peer
} }
meta["network"] = c.handler.String() meta["network"] = w.handler.String()
switch rsn { switch rsn {
case reasonShutdown: case reasonShutdown:
severity = proto.SystemEvent_INFO severity = proto.SystemEvent_INFO
@ -400,13 +386,17 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
severity = proto.SystemEvent_WARNING severity = proto.SystemEvent_WARNING
message = "Default route disconnected due to peer unreachability" message = "Default route disconnected due to peer unreachability"
userMessage = "Exit node connection lost. Your internet access might be affected." userMessage = "Exit node connection lost. Your internet access might be affected."
case reasonHA:
severity = proto.SystemEvent_INFO
message = "Default route disconnected due to high availability change"
userMessage = "Exit node disconnected due to high availability change."
default: default:
severity = proto.SystemEvent_ERROR severity = proto.SystemEvent_ERROR
message = "Default route disconnected for unknown reasons" message = "Default route disconnected for unknown reasons"
userMessage = "Exit node disconnected for unknown reasons." userMessage = "Exit node disconnected for unknown reasons."
} }
c.statusRecorder.PublishEvent( w.statusRecorder.PublishEvent(
severity, severity,
proto.SystemEvent_NETWORK, proto.SystemEvent_NETWORK,
message, message,
@ -415,86 +405,101 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
) )
} }
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { func (w *Watcher) SendUpdate(update RoutesUpdate) {
go func() { go func() {
c.routeUpdate <- update select {
case w.routeUpdate <- update:
case <-w.ctx.Done():
}
}() }()
} }
func (c *clientNetwork) handleUpdate(update routesUpdate) bool { func (w *Watcher) classifyUpdate(update RoutesUpdate) bool {
isUpdateMapDifferent := false isUpdateMapDifferent := false
updateMap := make(map[route.ID]*route.Route) updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes { for _, r := range update.Routes {
updateMap[r.ID] = r updateMap[r.ID] = r
} }
if len(c.routes) != len(updateMap) { if len(w.routes) != len(updateMap) {
isUpdateMapDifferent = true isUpdateMapDifferent = true
} }
for id, r := range c.routes { for id, r := range w.routes {
_, found := updateMap[id] _, found := updateMap[id]
if !found { if !found {
close(c.routePeersNotifiers[r.Peer]) close(w.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer) delete(w.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true isUpdateMapDifferent = true
continue continue
} }
if !reflect.DeepEqual(c.routes[id], updateMap[id]) { if !reflect.DeepEqual(w.routes[id], updateMap[id]) {
isUpdateMapDifferent = true isUpdateMapDifferent = true
} }
} }
c.routes = updateMap w.routes = updateMap
return isUpdateMapDifferent return isUpdateMapDifferent
} }
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events. // Start is the main point of reacting on client network routing events.
// All the processing related to the client network should be done here. Thread-safe. // All the processing related to the client network should be done here. Thread-safe.
func (c *clientNetwork) peersStateAndUpdateWatcher() { func (w *Watcher) Start() {
for { for {
select { select {
case <-c.ctx.Done(): case <-w.ctx.Done():
log.Debugf("Stopping watcher for network [%v]", c.handler)
if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
}
return return
case <-c.peerStateUpdate: case <-w.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate) if err := w.recalculateRoutes(reasonPeerUpdate); err != nil {
if err != nil { log.Errorf("Failed to recalculate routes for network [%v]: %v", w.handler, err)
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
} }
case update := <-c.routeUpdate: case update := <-w.routeUpdate:
if update.updateSerial < c.updateSerial { if update.UpdateSerial < w.updateSerial {
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial) log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", w.updateSerial, update.UpdateSerial)
continue continue
} }
log.Debugf("Received a new client network route update for [%v]", c.handler) w.handleRouteUpdate(update)
}
}
}
func (w *Watcher) handleRouteUpdate(update RoutesUpdate) {
log.Debugf("Received a new client network route update for [%v]", w.handler)
// hash update somehow // hash update somehow
isTrueRouteUpdate := c.handleUpdate(update) isTrueRouteUpdate := w.classifyUpdate(update)
c.updateSerial = update.updateSerial w.updateSerial = update.UpdateSerial
if isTrueRouteUpdate { if isTrueRouteUpdate {
log.Debug("Client network update contains different routes, recalculating routes") log.Debugf("client network update %v for [%v] contains different routes, recalculating routes", update.UpdateSerial, w.handler)
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate) if err := w.recalculateRoutes(reasonRouteUpdate); err != nil {
if err != nil { log.Errorf("failed to recalculate routes for network [%v]: %v", w.handler, err)
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
} }
} else { } else {
log.Debug("Route update is not different, skipping route recalculation") log.Debugf("route update %v for [%v] is not different, skipping route recalculation", update.UpdateSerial, w.handler)
} }
c.startPeersStatusChangeWatcher() w.startNewPeerStatusWatchers()
} }
// Stop stops the watcher and cleans up resources.
func (w *Watcher) Stop() {
log.Debugf("Stopping watcher for network [%v]", w.handler)
w.cancel()
if w.currentChosen == nil {
return
}
if err := w.removeAllowedIPs(w.currentChosen, reasonShutdown); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", w.handler, err)
} }
} }
func handlerFromRoute( func HandlerFromRoute(
rt *route.Route, rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter, routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,

View File

@ -1,4 +1,4 @@
package routemanager package client
import ( import (
"fmt" "fmt"
@ -395,7 +395,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
} }
// create new clientNetwork // create new clientNetwork
client := &clientNetwork{ client := &Watcher{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil), handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes, routes: tc.existingRoutes,
currentChosen: currentRoute, currentChosen: currentRoute,

View File

@ -11,9 +11,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
@ -21,9 +23,11 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/server"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
@ -68,9 +72,9 @@ type DefaultManager struct {
ctx context.Context ctx context.Context
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork clientNetworks map[route.HAUniqueID]*client.Watcher
routeSelector *routeselector.RouteSelector routeSelector *routeselector.RouteSelector
serverRouter *serverRouter serverRouter *server.Router
sysOps *systemops.SysOps sysOps *systemops.SysOps
statusRecorder *peer.Status statusRecorder *peer.Status
relayMgr *relayClient.Manager relayMgr *relayClient.Manager
@ -88,6 +92,7 @@ type DefaultManager struct {
useNewDNSRoute bool useNewDNSRoute bool
disableClientRoutes bool disableClientRoutes bool
disableServerRoutes bool disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler
} }
func NewManager(config ManagerConfig) *DefaultManager { func NewManager(config ManagerConfig) *DefaultManager {
@ -99,7 +104,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
dnsRouteInterval: config.DNSRouteInterval, dnsRouteInterval: config.DNSRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*client.Watcher),
relayMgr: config.RelayManager, relayMgr: config.RelayManager,
sysOps: sysOps, sysOps: sysOps,
statusRecorder: config.StatusRecorder, statusRecorder: config.StatusRecorder,
@ -111,6 +116,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
peerStore: config.PeerStore, peerStore: config.PeerStore,
disableClientRoutes: config.DisableClientRoutes, disableClientRoutes: config.DisableClientRoutes,
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
} }
useNoop := netstack.IsEnabled() || config.DisableClientRoutes useNoop := netstack.IsEnabled() || config.DisableClientRoutes
@ -226,7 +232,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
} }
var err error var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil { if err != nil {
return err return err
} }
@ -237,7 +243,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop() m.stop()
if m.serverRouter != nil { if m.serverRouter != nil {
m.serverRouter.cleanUp() m.serverRouter.CleanUp()
} }
if m.routeRefCounter != nil { if m.routeRefCounter != nil {
@ -265,6 +271,54 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
toAdd := make(map[route.HAUniqueID]*route.Route)
toRemove := make(map[route.HAUniqueID]client.RouteHandler)
for id, routes := range newRoutes {
if len(routes) > 0 {
toAdd[id] = routes[0]
}
}
for id, activeHandler := range m.activeRoutes {
if _, exists := toAdd[id]; exists {
delete(toAdd, id)
} else {
toRemove[id] = activeHandler
}
}
var merr *multierror.Error
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
}
delete(m.activeRoutes, id)
}
for id, route := range toAdd {
handler := client.HandlerFromRoute(
route,
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsRouteInterval,
m.statusRecorder,
m.wgInterface,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue
}
m.activeRoutes[id] = handler
}
return nberrors.FormatErrorOrNil(merr)
}
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@ -281,6 +335,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
if !m.disableClientRoutes { if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
log.Errorf("Failed to update system routes: %v", err)
}
m.updateClientNetworks(updateSerial, filteredClientRoutes) m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes)
} }
@ -290,7 +349,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
return nil return nil
} }
if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
return fmt.Errorf("update routes: %w", err) return fmt.Errorf("update routes: %w", err)
} }
@ -341,6 +400,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.notifier.OnNewRoutes(networks) m.notifier.OnNewRoutes(networks)
if err := m.updateSystemRoutes(networks); err != nil {
log.Errorf("failed to update system routes during selection: %v", err)
}
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)
for id, routes := range networks { for id, routes := range networks {
@ -349,21 +412,24 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue continue
} }
clientNetworkWatcher := newClientNetworkWatcher( handler := m.activeRoutes[id]
m.ctx, if handler == nil {
m.dnsRouteInterval, log.Warnf("no active handler found for route %s", id)
m.wgInterface, continue
m.statusRecorder, }
routes[0],
m.routeRefCounter, config := client.WatcherConfig{
m.allowedIPsRefCounter, Context: m.ctx,
m.dnsServer, DNSRouteInterval: m.dnsRouteInterval,
m.peerStore, WGInterface: m.wgInterface,
m.useNewDNSRoute, StatusRecorder: m.statusRecorder,
) Route: routes[0],
Handler: handler,
}
clientNetworkWatcher := client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.Start()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
} }
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil { if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
@ -375,8 +441,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) { func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok { if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id) client.Stop()
client.cancel()
delete(m.clientNetworks, id) delete(m.clientNetworks, id)
} }
} }
@ -389,26 +454,29 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher( handler := m.activeRoutes[id]
m.ctx, if handler == nil {
m.dnsRouteInterval, log.Errorf("No active handler found for route %s", id)
m.wgInterface, continue
m.statusRecorder, }
routes[0],
m.routeRefCounter, config := client.WatcherConfig{
m.allowedIPsRefCounter, Context: m.ctx,
m.dnsServer, DNSRouteInterval: m.dnsRouteInterval,
m.peerStore, WGInterface: m.wgInterface,
m.useNewDNSRoute, StatusRecorder: m.statusRecorder,
) Route: routes[0],
Handler: handler,
}
clientNetworkWatcher = client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.Start()
} }
update := routesUpdate{ update := client.RoutesUpdate{
updateSerial: updateSerial, UpdateSerial: updateSerial,
routes: routes, Routes: routes,
} }
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) clientNetworkWatcher.SendUpdate(update)
} }
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime"
"testing" "testing"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
@ -454,8 +453,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil { if routeManager.serverRouter != nil {
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") require.Equal(t, testCase.serverRoutesExpected, routeManager.serverRouter.RoutesCount(), "server networks size should match")
} }
}) })
} }

View File

@ -1,4 +1,4 @@
package routemanager package server
import ( import (
"context" "context"
@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type serverRouter struct { type Router struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context ctx context.Context
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
@ -23,8 +23,8 @@ type serverRouter struct {
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) {
return &serverRouter{ return &Router{
ctx: ctx, ctx: ctx,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),
firewall: firewall, firewall: firewall,
@ -33,104 +33,110 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi
}, nil }, nil
} }
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error { func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
m.mux.Lock() r.mux.Lock()
defer m.mux.Unlock() defer r.mux.Unlock()
serverRoutesToRemove := make([]route.ID, 0) serverRoutesToRemove := make([]route.ID, 0)
for routeID := range m.routes { for routeID := range r.routes {
update, found := routesMap[routeID] update, found := routesMap[routeID]
if !found || !update.Equal(m.routes[routeID]) { if !found || !update.Equal(r.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID) serverRoutesToRemove = append(serverRoutesToRemove, routeID)
} }
} }
for _, routeID := range serverRoutesToRemove { for _, routeID := range serverRoutesToRemove {
oldRoute := m.routes[routeID] oldRoute := r.routes[routeID]
err := m.removeFromServerNetwork(oldRoute) err := r.removeFromServerNetwork(oldRoute)
if err != nil { if err != nil {
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err) oldRoute.ID, oldRoute.Network, err)
} }
delete(m.routes, routeID) delete(r.routes, routeID)
} }
// If routing is to be disabled, do it after routes have been removed // If routing is to be disabled, do it after routes have been removed
// If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled // If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled
if len(routesMap) > 0 { if len(routesMap) > 0 {
if err := m.firewall.EnableRouting(); err != nil { if err := r.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err) return fmt.Errorf("enable routing: %w", err)
} }
} else { } else {
if err := m.firewall.DisableRouting(); err != nil { if err := r.firewall.DisableRouting(); err != nil {
return fmt.Errorf("disable routing: %w", err) return fmt.Errorf("disable routing: %w", err)
} }
} }
for id, newRoute := range routesMap { for id, newRoute := range routesMap {
_, found := m.routes[id] _, found := r.routes[id]
if found { if found {
continue continue
} }
err := m.addToServerNetwork(newRoute, useNewDNSRoute) err := r.addToServerNetwork(newRoute, useNewDNSRoute)
if err != nil { if err != nil {
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
continue continue
} }
m.routes[id] = newRoute r.routes[id] = newRoute
} }
return nil return nil
} }
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { func (r *Router) removeFromServerNetwork(route *route.Route) error {
if m.ctx.Err() != nil { if r.ctx.Err() != nil {
log.Infof("Not removing from server network because context is done") log.Infof("Not removing from server network because context is done")
return m.ctx.Err() return r.ctx.Err()
} }
routerPair := routeToRouterPair(route, false) routerPair := routeToRouterPair(route, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil { if err := r.firewall.RemoveNatRule(routerPair); err != nil {
return fmt.Errorf("remove routing rules: %w", err) return fmt.Errorf("remove routing rules: %w", err)
} }
delete(m.routes, route.ID) delete(r.routes, route.ID)
m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString()) r.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
return nil return nil
} }
func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error { func (r *Router) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
if m.ctx.Err() != nil { if r.ctx.Err() != nil {
log.Infof("Not adding to server network because context is done") log.Infof("Not adding to server network because context is done")
return m.ctx.Err() return r.ctx.Err()
} }
routerPair := routeToRouterPair(route, useNewDNSRoute) routerPair := routeToRouterPair(route, useNewDNSRoute)
if err := m.firewall.AddNatRule(routerPair); err != nil { if err := r.firewall.AddNatRule(routerPair); err != nil {
return fmt.Errorf("insert routing rules: %w", err) return fmt.Errorf("insert routing rules: %w", err)
} }
m.routes[route.ID] = route r.routes[route.ID] = route
m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID()) r.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
return nil return nil
} }
func (m *serverRouter) cleanUp() { func (r *Router) CleanUp() {
m.mux.Lock() r.mux.Lock()
defer m.mux.Unlock() defer r.mux.Unlock()
for _, r := range m.routes { for _, route := range r.routes {
routerPair := routeToRouterPair(r, false) routerPair := routeToRouterPair(route, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil { if err := r.firewall.RemoveNatRule(routerPair); err != nil {
log.Errorf("Failed to remove cleanup route: %v", err) log.Errorf("Failed to remove cleanup route: %v", err)
} }
} }
m.statusRecorder.CleanLocalPeerStateRoutes() r.statusRecorder.CleanLocalPeerStateRoutes()
}
func (r *Router) RoutesCount() int {
r.mux.Lock()
defer r.mux.Unlock()
return len(r.routes)
} }
func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair { func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {

View File

@ -29,14 +29,18 @@ func (r *Route) String() string {
} }
func (r *Route) AddRoute(context.Context) error { func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) if _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}); err != nil {
return err return err
} }
return nil
}
func (r *Route) RemoveRoute() error { func (r *Route) RemoveRoute() error {
_, err := r.routeRefCounter.Decrement(r.route.Network) if _, err := r.routeRefCounter.Decrement(r.route.Network); err != nil {
return err return err
} }
return nil
}
func (r *Route) AddAllowedIPs(peerKey string) error { func (r *Route) AddAllowedIPs(peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil { if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
@ -51,6 +55,8 @@ func (r *Route) AddAllowedIPs(peerKey string) error {
} }
func (r *Route) RemoveAllowedIPs() error { func (r *Route) RemoveAllowedIPs() error {
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network) if _, err := r.allowedIPsRefcounter.Decrement(r.route.Network); err != nil {
return err return err
} }
return nil
}