This commit is contained in:
Viktor Liu 2024-12-10 19:14:42 +01:00
parent 9d820f1eae
commit d802b7b9ba
10 changed files with 132 additions and 94 deletions

View File

@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
} }
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
log.Debugf("registering dns handler for pattern: %s", pattern)
s.dnsMux.Handle(pattern, handler) s.dnsMux.Handle(pattern, handler)
} }

View File

@ -68,6 +68,7 @@ func (s *ServiceViaMemory) Stop() {
} }
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
log.Debugf("registering dns handler for pattern: %s", pattern)
s.dnsMux.Handle(pattern, handler) s.dnsMux.Handle(pattern, handler)
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"maps"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
@ -137,10 +136,6 @@ type Engine struct {
TURNs []*stun.URI TURNs []*stun.URI
stunTurn atomic.Value stunTurn atomic.Value
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
clientRoutesMu sync.RWMutex
clientCtx context.Context clientCtx context.Context
clientCancel context.CancelFunc clientCancel context.CancelFunc
@ -300,10 +295,6 @@ func (e *Engine) Stop() error {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
} }
e.clientRoutesMu.Lock()
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
if e.cancel != nil { if e.cancel != nil {
e.cancel() e.cancel()
} }
@ -383,6 +374,7 @@ func (e *Engine) Start() error {
initialRoutes, initialRoutes,
e.stateManager, e.stateManager,
dnsServer, dnsServer,
e.peerConns,
) )
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
@ -812,15 +804,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoRoutes = []*mgmProto.Route{} protoRoutes = []*mgmProto.Route{}
} }
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) if err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)); err != nil {
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())
@ -868,8 +855,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
} }
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
if err != nil {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
@ -1136,7 +1122,7 @@ func (e *Engine) receiveSignalEvents() {
return err return err
} }
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
case sProto.Body_MODE: case sProto.Body_MODE:
} }
@ -1323,26 +1309,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
} }
} }
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
routes[id.NetID()] = v
}
return routes
}
// GetRouteManager returns the route manager // GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager { func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager return e.routeManager
@ -1506,7 +1472,7 @@ func (e *Engine) startNetworkMonitor() {
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() { for _, routes := range e.routeManager.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil { if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network) vpnRoutes = append(vpnRoutes, routes[0].Network)
} }

View File

@ -251,7 +251,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, },
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil)
_, _, err = engine.routeManager.Init() _, _, err = engine.routeManager.Init()
require.NoError(t, err) require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{

View File

@ -65,6 +65,7 @@ func newClientNetworkWatcher(
routeRefCounter *refcounter.RouteRefCounter, routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerConns map[string]*peer.Conn,
) *clientNetwork { ) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@ -77,7 +78,16 @@ func newClientNetworkWatcher(
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface, dnsServer), handler: handlerFromRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouteInterval,
statusRecorder,
wgInterface,
dnsServer,
peerConns,
),
} }
return client return client
} }
@ -388,13 +398,29 @@ func handlerFromRoute(
statusRecorder *peer.Status, statusRecorder *peer.Status,
wgInterface iface.IWGIface, wgInterface iface.IWGIface,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerConns map[string]*peer.Conn,
) RouteHandler { ) RouteHandler {
if rt.IsDynamic() { if rt.IsDynamic() {
if useNewDNSRoute { if useNewDNSRoute {
return dnsinterceptor.New(rt, routeRefCounter, allowedIPsRefCounter, statusRecorder, dnsServer) return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
peerConns,
)
} }
dns := nbdns.NewServiceViaMemory(wgInterface) dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
} }
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
} }

View File

@ -24,6 +24,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server dnsServer nbdns.Server
currentPeerKey string currentPeerKey string
interceptedIPs map[string]netip.Prefix interceptedIPs map[string]netip.Prefix
peerConns map[string]*peer.Conn
} }
func New( func New(
@ -32,6 +33,7 @@ func New(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status, statusRecorder *peer.Status,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerConns map[string]*peer.Conn,
) *DnsInterceptor { ) *DnsInterceptor {
return &DnsInterceptor{ return &DnsInterceptor{
route: rt, route: rt,
@ -40,6 +42,7 @@ func New(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
dnsServer: dnsServer, dnsServer: dnsServer,
interceptedIPs: make(map[string]netip.Prefix), interceptedIPs: make(map[string]netip.Prefix),
peerConns: peerConns,
} }
} }

View File

@ -12,6 +12,7 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
@ -34,9 +35,11 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error EnableServerRouter(firewall firewall.Manager) error
@ -61,7 +64,10 @@ type DefaultManager struct {
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
stateManager *statemanager.Manager stateManager *statemanager.Manager
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
dnsServer dns.Server dnsServer dns.Server
peerConns map[string]*peer.Conn
} }
func NewManager( func NewManager(
@ -74,6 +80,7 @@ func NewManager(
initialRoutes []*route.Route, initialRoutes []*route.Route,
stateManager *statemanager.Manager, stateManager *statemanager.Manager,
dnsServer dns.Server, dnsServer dns.Server,
peerConns map[string]*peer.Conn,
) *DefaultManager { ) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
@ -92,6 +99,7 @@ func NewManager(
notifier: notifier, notifier: notifier,
stateManager: stateManager, stateManager: stateManager,
dnsServer: dnsServer, dnsServer: dnsServer,
peerConns: peerConns,
} }
dm.routeRefCounter = refcounter.New( dm.routeRefCounter = refcounter.New(
@ -120,7 +128,7 @@ func NewManager(
) )
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes) cr := dm.initialClientRoutes(initialRoutes)
dm.notifier.SetInitialClientRoutes(cr) dm.notifier.SetInitialClientRoutes(cr)
} }
return dm return dm
@ -211,15 +219,21 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
m.ctx = nil m.ctx = nil
m.mux.Lock()
defer m.mux.Unlock()
m.clientRoutes = nil
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not updating routes as context is closed") log.Infof("not updating routes as context is closed")
return nil, nil, m.ctx.Err() return nil
default: default:
}
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@ -232,12 +246,13 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
if m.serverRouter != nil { if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("update routes: %w", err) return err
} }
} }
return newServerRoutesMap, newClientRoutesIDMap, nil m.clientRoutes = newClientRoutesIDMap
}
return nil
} }
// SetRouteChangeListener set RouteListener for route change Notifier // SetRouteChangeListener set RouteListener for route change Notifier
@ -255,9 +270,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector return m.routeSelector
} }
// GetClientRoutes returns the client routes // GetClientRoutes returns most recent list of clientRoutes received from the Management Service
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { func (m *DefaultManager) GetClientRoutes() route.HAMap {
return m.clientNetworks m.mux.Lock()
defer m.mux.Unlock()
return maps.Clone(m.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
defer m.mux.Unlock()
routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
for id, v := range m.clientRoutes {
routes[id.NetID()] = v
}
return routes
} }
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
@ -286,6 +316,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.routeRefCounter, m.routeRefCounter,
m.allowedIPsRefCounter, m.allowedIPsRefCounter,
m.dnsServer, m.dnsServer,
m.peerConns,
) )
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
@ -315,16 +346,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher( clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, nil)
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }
@ -367,7 +389,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
return newServerRoutesMap, newClientRoutesIDMap return newServerRoutesMap, newClientRoutesIDMap
} }
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes) _, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0, len(crMap)) rs := make([]*route.Route, 0, len(crMap))
for _, routes := range crMap { for _, routes := range crMap {

View File

@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil)
_, _, err = routeManager.Init() _, _, err = routeManager.Init()
@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
} }
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@ -2,7 +2,6 @@ package routemanager
import ( import (
"context" "context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
@ -15,9 +14,11 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
TriggerSelectionFunc func(haMap route.HAMap) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
} }
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
if m.UpdateRoutesFunc != nil { if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes) return m.UpdateRoutesFunc(updateSerial, newRoutes)
} }
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") return nil
} }
func (m *MockManager) TriggerSelection(networks route.HAMap) { func (m *MockManager) TriggerSelection(networks route.HAMap) {
@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
return nil return nil
} }
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
func (m *MockManager) GetClientRoutes() route.HAMap {
if m.GetClientRoutesFunc != nil {
return m.GetClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {
return m.GetClientRoutesWithNetIDFunc()
}
return nil
}
// Start mock implementation of Start from Manager interface // Start mock implementation of Start from Manager interface
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
} }

View File

@ -34,7 +34,7 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }
routesMap := engine.GetClientRoutesWithNetID() routesMap := engine.GetRouteManager().GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector() routeSelector := engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute var routes []*selectRoute
@ -116,11 +116,12 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest)
routeSelector.SelectAllRoutes() routeSelector.SelectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetRouteIDs()) routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { netIdRoutes := maps.Keys(engine.GetRouteManager().GetClientRoutesWithNetID())
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
return nil, fmt.Errorf("select routes: %w", err) return nil, fmt.Errorf("select routes: %w", err)
} }
} }
routeManager.TriggerSelection(engine.GetClientRoutes()) routeManager.TriggerSelection(engine.GetRouteManager().GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil return &proto.SelectRoutesResponse{}, nil
} }
@ -145,11 +146,12 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
routeSelector.DeselectAllRoutes() routeSelector.DeselectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetRouteIDs()) routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { netIdRoutes := maps.Keys(engine.GetRouteManager().GetClientRoutesWithNetID())
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err) return nil, fmt.Errorf("deselect routes: %w", err)
} }
} }
routeManager.TriggerSelection(engine.GetClientRoutes()) routeManager.TriggerSelection(engine.GetRouteManager().GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil return &proto.SelectRoutesResponse{}, nil
} }