mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 02:08:40 +02:00
[client] Apply routes right away instead of on peer connection (#3907)
This commit is contained in:
parent
1ce4ee0cef
commit
06980e7fa0
@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("extra match domains: %v", s.extraDomains)
|
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||||
|
@ -994,6 +994,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protoDNSConfig := networkMap.GetDNSConfig()
|
||||||
|
if protoDNSConfig == nil {
|
||||||
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||||
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
|
|
||||||
// apply routes first, route related actions might depend on routing being enabled
|
// apply routes first, route related actions might depend on routing being enabled
|
||||||
@ -1061,15 +1070,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
|
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
|
||||||
e.connMgr.SetExcludeList(excludedLazyPeers)
|
e.connMgr.SetExcludeList(excludedLazyPeers)
|
||||||
|
|
||||||
protoDNSConfig := networkMap.GetDNSConfig()
|
|
||||||
if protoDNSConfig == nil {
|
|
||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
|
|
||||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package routemanager
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -7,10 +7,8 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
@ -36,6 +34,7 @@ const (
|
|||||||
reasonRouteUpdate
|
reasonRouteUpdate
|
||||||
reasonPeerUpdate
|
reasonPeerUpdate
|
||||||
reasonShutdown
|
reasonShutdown
|
||||||
|
reasonHA
|
||||||
)
|
)
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
@ -44,9 +43,9 @@ type routerPeerStatus struct {
|
|||||||
latency time.Duration
|
latency time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type routesUpdate struct {
|
type RoutesUpdate struct {
|
||||||
updateSerial uint64
|
UpdateSerial uint64
|
||||||
routes []*route.Route
|
Routes []*route.Route
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteHandler defines the interface for handling routes
|
// RouteHandler defines the interface for handling routes
|
||||||
@ -58,64 +57,54 @@ type RouteHandler interface {
|
|||||||
RemoveAllowedIPs() error
|
RemoveAllowedIPs() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientNetwork struct {
|
type WatcherConfig struct {
|
||||||
|
Context context.Context
|
||||||
|
DNSRouteInterval time.Duration
|
||||||
|
WGInterface iface.WGIface
|
||||||
|
StatusRecorder *peer.Status
|
||||||
|
Route *route.Route
|
||||||
|
Handler RouteHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watcher watches route and peer changes and updates allowed IPs accordingly.
|
||||||
|
// Once stopped, it cannot be reused.
|
||||||
|
type Watcher struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface iface.WGIface
|
wgInterface iface.WGIface
|
||||||
routes map[route.ID]*route.Route
|
routes map[route.ID]*route.Route
|
||||||
routeUpdate chan routesUpdate
|
routeUpdate chan RoutesUpdate
|
||||||
peerStateUpdate chan struct{}
|
peerStateUpdate chan struct{}
|
||||||
routePeersNotifiers map[string]chan struct{}
|
routePeersNotifiers map[string]chan struct{} // map of peer key to channel for peer state changes
|
||||||
currentChosen *route.Route
|
currentChosen *route.Route
|
||||||
handler RouteHandler
|
handler RouteHandler
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientNetworkWatcher(
|
func NewWatcher(config WatcherConfig) *Watcher {
|
||||||
ctx context.Context,
|
ctx, cancel := context.WithCancel(config.Context)
|
||||||
dnsRouteInterval time.Duration,
|
|
||||||
wgInterface iface.WGIface,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
rt *route.Route,
|
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
dnsServer nbdns.Server,
|
|
||||||
peerStore *peerstore.Store,
|
|
||||||
useNewDNSRoute bool,
|
|
||||||
) *clientNetwork {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
client := &clientNetwork{
|
client := &Watcher{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: config.StatusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: config.WGInterface,
|
||||||
routes: make(map[route.ID]*route.Route),
|
routes: make(map[route.ID]*route.Route),
|
||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan RoutesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
handler: handlerFromRoute(
|
handler: config.Handler,
|
||||||
rt,
|
|
||||||
routeRefCounter,
|
|
||||||
allowedIPsRefCounter,
|
|
||||||
dnsRouteInterval,
|
|
||||||
statusRecorder,
|
|
||||||
wgInterface,
|
|
||||||
dnsServer,
|
|
||||||
peerStore,
|
|
||||||
useNewDNSRoute,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
func (w *Watcher) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||||
routePeerStatuses := make(map[route.ID]routerPeerStatus)
|
routePeerStatuses := make(map[route.ID]routerPeerStatus)
|
||||||
for _, r := range c.routes {
|
for _, r := range w.routes {
|
||||||
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
peerStatus, err := w.statusRecorder.GetPeer(r.Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("couldn't fetch peer state: %v", err)
|
log.Debugf("couldn't fetch peer state %v: %v", r.Peer, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
routePeerStatuses[r.ID] = routerPeerStatus{
|
routePeerStatuses[r.ID] = routerPeerStatus{
|
||||||
@ -128,7 +117,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getBestRouteFromStatuses determines the most optimal route from the available routes
|
// getBestRouteFromStatuses determines the most optimal route from the available routes
|
||||||
// within a clientNetwork, taking into account peer connection status, route metrics, and
|
// within a Watcher, taking into account peer connection status, route metrics, and
|
||||||
// preference for non-relayed and direct connections.
|
// preference for non-relayed and direct connections.
|
||||||
//
|
//
|
||||||
// It follows these prioritization rules:
|
// It follows these prioritization rules:
|
||||||
@ -140,17 +129,17 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
|||||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||||
//
|
//
|
||||||
// It returns the ID of the selected optimal route.
|
// It returns the ID of the selected optimal route.
|
||||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
func (w *Watcher) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||||
chosen := route.ID("")
|
var chosen route.ID
|
||||||
chosenScore := float64(0)
|
chosenScore := float64(0)
|
||||||
currScore := float64(0)
|
currScore := float64(0)
|
||||||
|
|
||||||
currID := route.ID("")
|
var currID route.ID
|
||||||
if c.currentChosen != nil {
|
if w.currentChosen != nil {
|
||||||
currID = c.currentChosen.ID
|
currID = w.currentChosen.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range c.routes {
|
for _, r := range w.routes {
|
||||||
tempScore := float64(0)
|
tempScore := float64(0)
|
||||||
peerStatus, found := routePeerStatuses[r.ID]
|
peerStatus, found := routePeerStatuses[r.ID]
|
||||||
if !found || !peerStatus.connected {
|
if !found || !peerStatus.connected {
|
||||||
@ -167,7 +156,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
|||||||
if peerStatus.latency != 0 {
|
if peerStatus.latency != 0 {
|
||||||
latency = peerStatus.latency
|
latency = peerStatus.latency
|
||||||
} else {
|
} else {
|
||||||
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
log.Tracef("peer %s has 0 latency, range %s", r.Peer, w.handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// avoid negative tempScore on the higher latency calculation
|
// avoid negative tempScore on the higher latency calculation
|
||||||
@ -197,35 +186,45 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
chosenID := chosen
|
||||||
|
if chosen == "" {
|
||||||
|
chosenID = "<none>"
|
||||||
|
}
|
||||||
|
currentID := currID
|
||||||
|
if currID == "" {
|
||||||
|
currentID = "<none>"
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosenID, chosenScore, currentID, currScore)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case chosen == "":
|
case chosen == "":
|
||||||
var peers []string
|
var peers []string
|
||||||
for _, r := range c.routes {
|
for _, r := range w.routes {
|
||||||
peers = append(peers, r.Peer)
|
peers = append(peers, r.Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
|
log.Infof("network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", w.handler, peers)
|
||||||
case chosen != currID:
|
case chosen != currID:
|
||||||
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
||||||
if currScore != 0 && currScore+0.01 > chosenScore {
|
if currScore != 0 && currScore+0.01 > chosenScore {
|
||||||
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
log.Debugf("keeping current routing peer %s for [%v]: the score difference with latency is less than 0.01(10ms): current: %f, new: %f",
|
||||||
|
w.currentChosen.Peer, w.handler, currScore, chosenScore)
|
||||||
return currID
|
return currID
|
||||||
}
|
}
|
||||||
var p string
|
var p string
|
||||||
if rt := c.routes[chosen]; rt != nil {
|
if rt := w.routes[chosen]; rt != nil {
|
||||||
p = rt.Peer
|
p = rt.Peer
|
||||||
}
|
}
|
||||||
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
|
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, w.handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
return chosen
|
return chosen
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
|
func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
|
||||||
subscription := c.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey)
|
subscription := w.statusRecorder.SubscribeToPeerStateChanges(ctx, peerKey)
|
||||||
defer c.statusRecorder.UnsubscribePeerStateChanges(subscription)
|
defer w.statusRecorder.UnsubscribePeerStateChanges(subscription)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@ -240,105 +239,92 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
func (w *Watcher) startNewPeerStatusWatchers() {
|
||||||
for _, r := range c.routes {
|
for _, r := range w.routes {
|
||||||
_, found := c.routePeersNotifiers[r.Peer]
|
if _, found := w.routePeersNotifiers[r.Peer]; found {
|
||||||
if found {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
closerChan := make(chan struct{})
|
closerChan := make(chan struct{})
|
||||||
c.routePeersNotifiers[r.Peer] = closerChan
|
w.routePeersNotifiers[r.Peer] = closerChan
|
||||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
go w.watchPeerStatusChanges(w.ctx, r.Peer, w.peerStateUpdate, closerChan)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
// addAllowedIPs adds the allowed IPs for the current chosen route to the handler.
|
||||||
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
func (w *Watcher) addAllowedIPs(route *route.Route) error {
|
||||||
|
if err := w.handler.AddAllowedIPs(route.Peer); err != nil {
|
||||||
|
return fmt.Errorf("add allowed IPs for peer %s: %w", route.Peer, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.statusRecorder.AddPeerStateRoute(route.Peer, w.handler.String(), route.GetResourceID()); err != nil {
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
w.connectEvent(route)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) removeAllowedIPs(route *route.Route, rsn reason) error {
|
||||||
|
if err := w.statusRecorder.RemovePeerStateRoute(route.Peer, w.handler.String()); err != nil {
|
||||||
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.handler.RemoveAllowedIPs(); err != nil {
|
||||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w.disconnectEvent(route, rsn)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromPeerAndSystem(rsn reason) error {
|
func (w *Watcher) recalculateRoutes(rsn reason) error {
|
||||||
if c.currentChosen == nil {
|
routerPeerStatuses := w.getRouterPeerStatuses()
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
newChosenID := w.getBestRouteFromStatuses(routerPeerStatuses)
|
||||||
|
|
||||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
// If no route is chosen, remove the route from the peer
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
|
||||||
}
|
|
||||||
if err := c.handler.RemoveRoute(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
c.disconnectEvent(rsn)
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error {
|
|
||||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
|
||||||
|
|
||||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
|
||||||
|
|
||||||
// If no route is chosen, remove the route from the peer and system
|
|
||||||
if newChosenID == "" {
|
if newChosenID == "" {
|
||||||
if err := c.removeRouteFromPeerAndSystem(rsn); err != nil {
|
if w.currentChosen == nil {
|
||||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.currentChosen = nil
|
if err := w.removeAllowedIPs(w.currentChosen, rsn); err != nil {
|
||||||
|
return fmt.Errorf("remove obsolete: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.currentChosen = nil
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the chosen route is the same as the current route, do nothing
|
// If the chosen route is the same as the current route, do nothing
|
||||||
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
if w.currentChosen != nil && w.currentChosen.ID == newChosenID &&
|
||||||
c.currentChosen.Equal(c.routes[newChosenID]) {
|
w.currentChosen.Equal(w.routes[newChosenID]) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var isNew bool
|
// If the chosen route was assigned to a different peer, remove the allowed IPs first
|
||||||
if c.currentChosen == nil {
|
if isNew := w.currentChosen == nil; !isNew {
|
||||||
// If they were not previously assigned to another peer, add routes to the system first
|
if err := w.removeAllowedIPs(w.currentChosen, reasonHA); err != nil {
|
||||||
if err := c.handler.AddRoute(c.ctx); err != nil {
|
return fmt.Errorf("remove old: %w", err)
|
||||||
return fmt.Errorf("add route: %w", err)
|
|
||||||
}
|
|
||||||
isNew = true
|
|
||||||
} else {
|
|
||||||
// Otherwise, remove the allowed IPs from the previous peer first
|
|
||||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
|
||||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.currentChosen = c.routes[newChosenID]
|
newChosenRoute := w.routes[newChosenID]
|
||||||
|
if err := w.addAllowedIPs(newChosenRoute); err != nil {
|
||||||
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
|
return fmt.Errorf("add new: %w", err)
|
||||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if isNew {
|
w.currentChosen = newChosenRoute
|
||||||
c.connectEvent()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("add peer state route: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) connectEvent() {
|
func (w *Watcher) connectEvent(route *route.Route) {
|
||||||
var defaultRoute bool
|
var defaultRoute bool
|
||||||
for _, r := range c.routes {
|
for _, r := range w.routes {
|
||||||
if r.Network.Bits() == 0 {
|
if r.Network.Bits() == 0 {
|
||||||
defaultRoute = true
|
defaultRoute = true
|
||||||
break
|
break
|
||||||
@ -350,13 +336,13 @@ func (c *clientNetwork) connectEvent() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]string{
|
meta := map[string]string{
|
||||||
"network": c.handler.String(),
|
"network": w.handler.String(),
|
||||||
}
|
}
|
||||||
if c.currentChosen != nil {
|
if route != nil {
|
||||||
meta["id"] = string(c.currentChosen.NetID)
|
meta["id"] = string(route.NetID)
|
||||||
meta["peer"] = c.currentChosen.Peer
|
meta["peer"] = route.Peer
|
||||||
}
|
}
|
||||||
c.statusRecorder.PublishEvent(
|
w.statusRecorder.PublishEvent(
|
||||||
proto.SystemEvent_INFO,
|
proto.SystemEvent_INFO,
|
||||||
proto.SystemEvent_NETWORK,
|
proto.SystemEvent_NETWORK,
|
||||||
"Default route added",
|
"Default route added",
|
||||||
@ -365,9 +351,9 @@ func (c *clientNetwork) connectEvent() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) disconnectEvent(rsn reason) {
|
func (w *Watcher) disconnectEvent(route *route.Route, rsn reason) {
|
||||||
var defaultRoute bool
|
var defaultRoute bool
|
||||||
for _, r := range c.routes {
|
for _, r := range w.routes {
|
||||||
if r.Network.Bits() == 0 {
|
if r.Network.Bits() == 0 {
|
||||||
defaultRoute = true
|
defaultRoute = true
|
||||||
break
|
break
|
||||||
@ -383,11 +369,11 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
|
|||||||
var userMessage string
|
var userMessage string
|
||||||
meta := make(map[string]string)
|
meta := make(map[string]string)
|
||||||
|
|
||||||
if c.currentChosen != nil {
|
if route != nil {
|
||||||
meta["id"] = string(c.currentChosen.NetID)
|
meta["id"] = string(route.NetID)
|
||||||
meta["peer"] = c.currentChosen.Peer
|
meta["peer"] = route.Peer
|
||||||
}
|
}
|
||||||
meta["network"] = c.handler.String()
|
meta["network"] = w.handler.String()
|
||||||
switch rsn {
|
switch rsn {
|
||||||
case reasonShutdown:
|
case reasonShutdown:
|
||||||
severity = proto.SystemEvent_INFO
|
severity = proto.SystemEvent_INFO
|
||||||
@ -400,13 +386,17 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
|
|||||||
severity = proto.SystemEvent_WARNING
|
severity = proto.SystemEvent_WARNING
|
||||||
message = "Default route disconnected due to peer unreachability"
|
message = "Default route disconnected due to peer unreachability"
|
||||||
userMessage = "Exit node connection lost. Your internet access might be affected."
|
userMessage = "Exit node connection lost. Your internet access might be affected."
|
||||||
|
case reasonHA:
|
||||||
|
severity = proto.SystemEvent_INFO
|
||||||
|
message = "Default route disconnected due to high availability change"
|
||||||
|
userMessage = "Exit node disconnected due to high availability change."
|
||||||
default:
|
default:
|
||||||
severity = proto.SystemEvent_ERROR
|
severity = proto.SystemEvent_ERROR
|
||||||
message = "Default route disconnected for unknown reasons"
|
message = "Default route disconnected for unknown reasons"
|
||||||
userMessage = "Exit node disconnected for unknown reasons."
|
userMessage = "Exit node disconnected for unknown reasons."
|
||||||
}
|
}
|
||||||
|
|
||||||
c.statusRecorder.PublishEvent(
|
w.statusRecorder.PublishEvent(
|
||||||
severity,
|
severity,
|
||||||
proto.SystemEvent_NETWORK,
|
proto.SystemEvent_NETWORK,
|
||||||
message,
|
message,
|
||||||
@ -415,86 +405,101 @@ func (c *clientNetwork) disconnectEvent(rsn reason) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
func (w *Watcher) SendUpdate(update RoutesUpdate) {
|
||||||
go func() {
|
go func() {
|
||||||
c.routeUpdate <- update
|
select {
|
||||||
|
case w.routeUpdate <- update:
|
||||||
|
case <-w.ctx.Done():
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
|
func (w *Watcher) classifyUpdate(update RoutesUpdate) bool {
|
||||||
isUpdateMapDifferent := false
|
isUpdateMapDifferent := false
|
||||||
updateMap := make(map[route.ID]*route.Route)
|
updateMap := make(map[route.ID]*route.Route)
|
||||||
|
|
||||||
for _, r := range update.routes {
|
for _, r := range update.Routes {
|
||||||
updateMap[r.ID] = r
|
updateMap[r.ID] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.routes) != len(updateMap) {
|
if len(w.routes) != len(updateMap) {
|
||||||
isUpdateMapDifferent = true
|
isUpdateMapDifferent = true
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, r := range c.routes {
|
for id, r := range w.routes {
|
||||||
_, found := updateMap[id]
|
_, found := updateMap[id]
|
||||||
if !found {
|
if !found {
|
||||||
close(c.routePeersNotifiers[r.Peer])
|
close(w.routePeersNotifiers[r.Peer])
|
||||||
delete(c.routePeersNotifiers, r.Peer)
|
delete(w.routePeersNotifiers, r.Peer)
|
||||||
isUpdateMapDifferent = true
|
isUpdateMapDifferent = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
|
if !reflect.DeepEqual(w.routes[id], updateMap[id]) {
|
||||||
isUpdateMapDifferent = true
|
isUpdateMapDifferent = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.routes = updateMap
|
w.routes = updateMap
|
||||||
return isUpdateMapDifferent
|
return isUpdateMapDifferent
|
||||||
}
|
}
|
||||||
|
|
||||||
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
|
// Start is the main point of reacting on client network routing events.
|
||||||
// All the processing related to the client network should be done here. Thread-safe.
|
// All the processing related to the client network should be done here. Thread-safe.
|
||||||
func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
func (w *Watcher) Start() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.ctx.Done():
|
case <-w.ctx.Done():
|
||||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
|
||||||
if err := c.removeRouteFromPeerAndSystem(reasonShutdown); err != nil {
|
|
||||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
case <-c.peerStateUpdate:
|
case <-w.peerStateUpdate:
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonPeerUpdate)
|
if err := w.recalculateRoutes(reasonPeerUpdate); err != nil {
|
||||||
if err != nil {
|
log.Errorf("Failed to recalculate routes for network [%v]: %v", w.handler, err)
|
||||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
|
||||||
}
|
}
|
||||||
case update := <-c.routeUpdate:
|
case update := <-w.routeUpdate:
|
||||||
if update.updateSerial < c.updateSerial {
|
if update.UpdateSerial < w.updateSerial {
|
||||||
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
|
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", w.updateSerial, update.UpdateSerial)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Received a new client network route update for [%v]", c.handler)
|
w.handleRouteUpdate(update)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Watcher) handleRouteUpdate(update RoutesUpdate) {
|
||||||
|
log.Debugf("Received a new client network route update for [%v]", w.handler)
|
||||||
|
|
||||||
// hash update somehow
|
// hash update somehow
|
||||||
isTrueRouteUpdate := c.handleUpdate(update)
|
isTrueRouteUpdate := w.classifyUpdate(update)
|
||||||
|
|
||||||
c.updateSerial = update.updateSerial
|
w.updateSerial = update.UpdateSerial
|
||||||
|
|
||||||
if isTrueRouteUpdate {
|
if isTrueRouteUpdate {
|
||||||
log.Debug("Client network update contains different routes, recalculating routes")
|
log.Debugf("client network update %v for [%v] contains different routes, recalculating routes", update.UpdateSerial, w.handler)
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem(reasonRouteUpdate)
|
if err := w.recalculateRoutes(reasonRouteUpdate); err != nil {
|
||||||
if err != nil {
|
log.Errorf("failed to recalculate routes for network [%v]: %v", w.handler, err)
|
||||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debug("Route update is not different, skipping route recalculation")
|
log.Debugf("route update %v for [%v] is not different, skipping route recalculation", update.UpdateSerial, w.handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.startPeersStatusChangeWatcher()
|
w.startNewPeerStatusWatchers()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop stops the watcher and cleans up resources.
|
||||||
|
func (w *Watcher) Stop() {
|
||||||
|
log.Debugf("Stopping watcher for network [%v]", w.handler)
|
||||||
|
|
||||||
|
w.cancel()
|
||||||
|
|
||||||
|
if w.currentChosen == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := w.removeAllowedIPs(w.currentChosen, reasonShutdown); err != nil {
|
||||||
|
log.Errorf("Failed to remove routes for [%v]: %v", w.handler, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlerFromRoute(
|
func HandlerFromRoute(
|
||||||
rt *route.Route,
|
rt *route.Route,
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
@ -1,4 +1,4 @@
|
|||||||
package routemanager
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -395,7 +395,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create new clientNetwork
|
// create new clientNetwork
|
||||||
client := &clientNetwork{
|
client := &Watcher{
|
||||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||||
routes: tc.existingRoutes,
|
routes: tc.existingRoutes,
|
||||||
currentChosen: currentRoute,
|
currentChosen: currentRoute,
|
@ -11,9 +11,11 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
@ -21,9 +23,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/server"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
@ -68,9 +72,9 @@ type DefaultManager struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
clientNetworks map[route.HAUniqueID]*client.Watcher
|
||||||
routeSelector *routeselector.RouteSelector
|
routeSelector *routeselector.RouteSelector
|
||||||
serverRouter *serverRouter
|
serverRouter *server.Router
|
||||||
sysOps *systemops.SysOps
|
sysOps *systemops.SysOps
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
relayMgr *relayClient.Manager
|
relayMgr *relayClient.Manager
|
||||||
@ -88,6 +92,7 @@ type DefaultManager struct {
|
|||||||
useNewDNSRoute bool
|
useNewDNSRoute bool
|
||||||
disableClientRoutes bool
|
disableClientRoutes bool
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
|
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(config ManagerConfig) *DefaultManager {
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
@ -99,7 +104,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
dnsRouteInterval: config.DNSRouteInterval,
|
dnsRouteInterval: config.DNSRouteInterval,
|
||||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
clientNetworks: make(map[route.HAUniqueID]*client.Watcher),
|
||||||
relayMgr: config.RelayManager,
|
relayMgr: config.RelayManager,
|
||||||
sysOps: sysOps,
|
sysOps: sysOps,
|
||||||
statusRecorder: config.StatusRecorder,
|
statusRecorder: config.StatusRecorder,
|
||||||
@ -111,6 +116,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
peerStore: config.PeerStore,
|
peerStore: config.PeerStore,
|
||||||
disableClientRoutes: config.DisableClientRoutes,
|
disableClientRoutes: config.DisableClientRoutes,
|
||||||
disableServerRoutes: config.DisableServerRoutes,
|
disableServerRoutes: config.DisableServerRoutes,
|
||||||
|
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
|
||||||
}
|
}
|
||||||
|
|
||||||
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
||||||
@ -226,7 +232,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -237,7 +243,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
|||||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||||
m.stop()
|
m.stop()
|
||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.cleanUp()
|
m.serverRouter.CleanUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.routeRefCounter != nil {
|
if m.routeRefCounter != nil {
|
||||||
@ -265,6 +271,54 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
|
func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||||
|
toAdd := make(map[route.HAUniqueID]*route.Route)
|
||||||
|
toRemove := make(map[route.HAUniqueID]client.RouteHandler)
|
||||||
|
|
||||||
|
for id, routes := range newRoutes {
|
||||||
|
if len(routes) > 0 {
|
||||||
|
toAdd[id] = routes[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, activeHandler := range m.activeRoutes {
|
||||||
|
if _, exists := toAdd[id]; exists {
|
||||||
|
delete(toAdd, id)
|
||||||
|
} else {
|
||||||
|
toRemove[id] = activeHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for id, handler := range toRemove {
|
||||||
|
if err := handler.RemoveRoute(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
||||||
|
}
|
||||||
|
delete(m.activeRoutes, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, route := range toAdd {
|
||||||
|
handler := client.HandlerFromRoute(
|
||||||
|
route,
|
||||||
|
m.routeRefCounter,
|
||||||
|
m.allowedIPsRefCounter,
|
||||||
|
m.dnsRouteInterval,
|
||||||
|
m.statusRecorder,
|
||||||
|
m.wgInterface,
|
||||||
|
m.dnsServer,
|
||||||
|
m.peerStore,
|
||||||
|
m.useNewDNSRoute,
|
||||||
|
)
|
||||||
|
if err := handler.AddRoute(m.ctx); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.activeRoutes[id] = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
@ -281,6 +335,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
|
|
||||||
if !m.disableClientRoutes {
|
if !m.disableClientRoutes {
|
||||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||||
|
|
||||||
|
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
|
||||||
|
log.Errorf("Failed to update system routes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||||
}
|
}
|
||||||
@ -290,7 +349,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
|
if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
|
||||||
return fmt.Errorf("update routes: %w", err)
|
return fmt.Errorf("update routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -341,6 +400,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
|
|
||||||
m.notifier.OnNewRoutes(networks)
|
m.notifier.OnNewRoutes(networks)
|
||||||
|
|
||||||
|
if err := m.updateSystemRoutes(networks); err != nil {
|
||||||
|
log.Errorf("failed to update system routes during selection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
m.stopObsoleteClients(networks)
|
m.stopObsoleteClients(networks)
|
||||||
|
|
||||||
for id, routes := range networks {
|
for id, routes := range networks {
|
||||||
@ -349,21 +412,24 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
clientNetworkWatcher := newClientNetworkWatcher(
|
handler := m.activeRoutes[id]
|
||||||
m.ctx,
|
if handler == nil {
|
||||||
m.dnsRouteInterval,
|
log.Warnf("no active handler found for route %s", id)
|
||||||
m.wgInterface,
|
continue
|
||||||
m.statusRecorder,
|
}
|
||||||
routes[0],
|
|
||||||
m.routeRefCounter,
|
config := client.WatcherConfig{
|
||||||
m.allowedIPsRefCounter,
|
Context: m.ctx,
|
||||||
m.dnsServer,
|
DNSRouteInterval: m.dnsRouteInterval,
|
||||||
m.peerStore,
|
WGInterface: m.wgInterface,
|
||||||
m.useNewDNSRoute,
|
StatusRecorder: m.statusRecorder,
|
||||||
)
|
Route: routes[0],
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
clientNetworkWatcher := client.NewWatcher(config)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.Start()
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
||||||
@ -375,8 +441,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||||
for id, client := range m.clientNetworks {
|
for id, client := range m.clientNetworks {
|
||||||
if _, ok := networks[id]; !ok {
|
if _, ok := networks[id]; !ok {
|
||||||
log.Debugf("Stopping client network watcher, %s", id)
|
client.Stop()
|
||||||
client.cancel()
|
|
||||||
delete(m.clientNetworks, id)
|
delete(m.clientNetworks, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -389,26 +454,29 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
|||||||
for id, routes := range networks {
|
for id, routes := range networks {
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
if !found {
|
if !found {
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(
|
handler := m.activeRoutes[id]
|
||||||
m.ctx,
|
if handler == nil {
|
||||||
m.dnsRouteInterval,
|
log.Errorf("No active handler found for route %s", id)
|
||||||
m.wgInterface,
|
continue
|
||||||
m.statusRecorder,
|
}
|
||||||
routes[0],
|
|
||||||
m.routeRefCounter,
|
config := client.WatcherConfig{
|
||||||
m.allowedIPsRefCounter,
|
Context: m.ctx,
|
||||||
m.dnsServer,
|
DNSRouteInterval: m.dnsRouteInterval,
|
||||||
m.peerStore,
|
WGInterface: m.wgInterface,
|
||||||
m.useNewDNSRoute,
|
StatusRecorder: m.statusRecorder,
|
||||||
)
|
Route: routes[0],
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
clientNetworkWatcher = client.NewWatcher(config)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.Start()
|
||||||
}
|
}
|
||||||
update := routesUpdate{
|
update := client.RoutesUpdate{
|
||||||
updateSerial: updateSerial,
|
UpdateSerial: updateSerial,
|
||||||
routes: routes,
|
Routes: routes,
|
||||||
}
|
}
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
clientNetworkWatcher.SendUpdate(update)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
@ -454,8 +453,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
if routeManager.serverRouter != nil {
|
||||||
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
require.Equal(t, testCase.serverRoutesExpected, routeManager.serverRouter.RoutesCount(), "server networks size should match")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package routemanager
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type serverRouter struct {
|
type Router struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
routes map[route.ID]*route.Route
|
routes map[route.ID]*route.Route
|
||||||
@ -23,8 +23,8 @@ type serverRouter struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) {
|
||||||
return &serverRouter{
|
return &Router{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
routes: make(map[route.ID]*route.Route),
|
routes: make(map[route.ID]*route.Route),
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
@ -33,104 +33,110 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
|
func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
|
||||||
m.mux.Lock()
|
r.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer r.mux.Unlock()
|
||||||
|
|
||||||
serverRoutesToRemove := make([]route.ID, 0)
|
serverRoutesToRemove := make([]route.ID, 0)
|
||||||
|
|
||||||
for routeID := range m.routes {
|
for routeID := range r.routes {
|
||||||
update, found := routesMap[routeID]
|
update, found := routesMap[routeID]
|
||||||
if !found || !update.Equal(m.routes[routeID]) {
|
if !found || !update.Equal(r.routes[routeID]) {
|
||||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, routeID := range serverRoutesToRemove {
|
for _, routeID := range serverRoutesToRemove {
|
||||||
oldRoute := m.routes[routeID]
|
oldRoute := r.routes[routeID]
|
||||||
err := m.removeFromServerNetwork(oldRoute)
|
err := r.removeFromServerNetwork(oldRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
|
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
|
||||||
oldRoute.ID, oldRoute.Network, err)
|
oldRoute.ID, oldRoute.Network, err)
|
||||||
}
|
}
|
||||||
delete(m.routes, routeID)
|
delete(r.routes, routeID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If routing is to be disabled, do it after routes have been removed
|
// If routing is to be disabled, do it after routes have been removed
|
||||||
// If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled
|
// If routing is to be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled
|
||||||
if len(routesMap) > 0 {
|
if len(routesMap) > 0 {
|
||||||
if err := m.firewall.EnableRouting(); err != nil {
|
if err := r.firewall.EnableRouting(); err != nil {
|
||||||
return fmt.Errorf("enable routing: %w", err)
|
return fmt.Errorf("enable routing: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := m.firewall.DisableRouting(); err != nil {
|
if err := r.firewall.DisableRouting(); err != nil {
|
||||||
return fmt.Errorf("disable routing: %w", err)
|
return fmt.Errorf("disable routing: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, newRoute := range routesMap {
|
for id, newRoute := range routesMap {
|
||||||
_, found := m.routes[id]
|
_, found := r.routes[id]
|
||||||
if found {
|
if found {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := m.addToServerNetwork(newRoute, useNewDNSRoute)
|
err := r.addToServerNetwork(newRoute, useNewDNSRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
|
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.routes[id] = newRoute
|
r.routes[id] = newRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
func (r *Router) removeFromServerNetwork(route *route.Route) error {
|
||||||
if m.ctx.Err() != nil {
|
if r.ctx.Err() != nil {
|
||||||
log.Infof("Not removing from server network because context is done")
|
log.Infof("Not removing from server network because context is done")
|
||||||
return m.ctx.Err()
|
return r.ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
routerPair := routeToRouterPair(route, false)
|
routerPair := routeToRouterPair(route, false)
|
||||||
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
|
if err := r.firewall.RemoveNatRule(routerPair); err != nil {
|
||||||
return fmt.Errorf("remove routing rules: %w", err)
|
return fmt.Errorf("remove routing rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(m.routes, route.ID)
|
delete(r.routes, route.ID)
|
||||||
m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
|
r.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
|
func (r *Router) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
|
||||||
if m.ctx.Err() != nil {
|
if r.ctx.Err() != nil {
|
||||||
log.Infof("Not adding to server network because context is done")
|
log.Infof("Not adding to server network because context is done")
|
||||||
return m.ctx.Err()
|
return r.ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
routerPair := routeToRouterPair(route, useNewDNSRoute)
|
routerPair := routeToRouterPair(route, useNewDNSRoute)
|
||||||
if err := m.firewall.AddNatRule(routerPair); err != nil {
|
if err := r.firewall.AddNatRule(routerPair); err != nil {
|
||||||
return fmt.Errorf("insert routing rules: %w", err)
|
return fmt.Errorf("insert routing rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routes[route.ID] = route
|
r.routes[route.ID] = route
|
||||||
m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
|
r.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverRouter) cleanUp() {
|
func (r *Router) CleanUp() {
|
||||||
m.mux.Lock()
|
r.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer r.mux.Unlock()
|
||||||
|
|
||||||
for _, r := range m.routes {
|
for _, route := range r.routes {
|
||||||
routerPair := routeToRouterPair(r, false)
|
routerPair := routeToRouterPair(route, false)
|
||||||
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
|
if err := r.firewall.RemoveNatRule(routerPair); err != nil {
|
||||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.statusRecorder.CleanLocalPeerStateRoutes()
|
r.statusRecorder.CleanLocalPeerStateRoutes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) RoutesCount() int {
|
||||||
|
r.mux.Lock()
|
||||||
|
defer r.mux.Unlock()
|
||||||
|
return len(r.routes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {
|
func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {
|
@ -29,14 +29,18 @@ func (r *Route) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) AddRoute(context.Context) error {
|
func (r *Route) AddRoute(context.Context) error {
|
||||||
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
|
if _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Route) RemoveRoute() error {
|
func (r *Route) RemoveRoute() error {
|
||||||
_, err := r.routeRefCounter.Decrement(r.route.Network)
|
if _, err := r.routeRefCounter.Decrement(r.route.Network); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||||
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
|
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
|
||||||
@ -51,6 +55,8 @@ func (r *Route) AddAllowedIPs(peerKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) RemoveAllowedIPs() error {
|
func (r *Route) RemoveAllowedIPs() error {
|
||||||
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
|
if _, err := r.allowedIPsRefcounter.Decrement(r.route.Network); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user