diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index e0f9da26f..72dc4bc6e 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() { } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { + log.Debugf("registering dns handler for pattern: %s", pattern) s.dnsMux.Handle(pattern, handler) } diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 729b90cc0..e198249ff 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -68,6 +68,7 @@ func (s *ServiceViaMemory) Stop() { } func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { + log.Debugf("registering dns handler for pattern: %s", pattern) s.dnsMux.Handle(pattern, handler) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 0e622649a..f499eb3f0 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "maps" "math/rand" "net" "net/netip" @@ -137,10 +136,6 @@ type Engine struct { TURNs []*stun.URI stunTurn atomic.Value - // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes route.HAMap - clientRoutesMu sync.RWMutex - clientCtx context.Context clientCancel context.CancelFunc @@ -300,10 +295,6 @@ func (e *Engine) Stop() error { return fmt.Errorf("failed to remove all peers: %s", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = nil - e.clientRoutesMu.Unlock() - if e.cancel != nil { e.cancel() } @@ -383,6 +374,7 @@ func (e *Engine) Start() error { initialRoutes, e.stateManager, dnsServer, + e.peerConns, ) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { @@ -812,15 +804,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoRoutes = []*mgmProto.Route{} } - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { + if err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = clientRoutes - e.clientRoutesMu.Unlock() - log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) e.updateOfflinePeers(networkMap.GetOfflinePeers()) @@ -868,8 +855,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) - if err != nil { + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { log.Errorf("failed to update dns server, err: %v", err) } @@ -1136,7 +1122,7 @@ func (e *Engine) receiveSignalEvents() { return err } - go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) + go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes()) case sProto.Body_MODE: } @@ -1323,26 +1309,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { } } -// GetClientRoutes returns the current routes from the route map -func (e *Engine) GetClientRoutes() route.HAMap { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - return maps.Clone(e.clientRoutes) -} - -// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only -func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) - for id, v := range e.clientRoutes { - routes[id.NetID()] = v - } - return routes -} - // GetRouteManager returns the route manager func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager @@ -1506,7 +1472,7 @@ func (e *Engine) startNetworkMonitor() { func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { var vpnRoutes []netip.Prefix - for _, routes := range e.GetClientRoutes() { + for _, routes := range e.routeManager.GetClientRoutes() { if len(routes) > 0 && routes[0] != nil { vpnRoutes = append(vpnRoutes, routes[0].Network) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 5141fa04d..4a8f369b6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -251,7 +251,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil) _, _, err = engine.routeManager.Init() require.NoError(t, err) engine.dnsServer = &dns.MockServer{ diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 9ef1855b0..4a5cc50f3 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -65,6 +65,7 @@ func newClientNetworkWatcher( routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsServer nbdns.Server, + peerConns map[string]*peer.Conn, ) *clientNetwork { ctx, cancel := context.WithCancel(ctx) @@ -77,7 +78,16 @@ func newClientNetworkWatcher( routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface, dnsServer), + handler: handlerFromRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouteInterval, + statusRecorder, + wgInterface, + dnsServer, + peerConns, + ), } return client } @@ -388,13 +398,29 @@ func handlerFromRoute( statusRecorder *peer.Status, wgInterface iface.IWGIface, dnsServer nbdns.Server, + peerConns map[string]*peer.Conn, ) RouteHandler { if rt.IsDynamic() { if useNewDNSRoute { - return dnsinterceptor.New(rt, routeRefCounter, allowedIPsRefCounter, statusRecorder, dnsServer) + return dnsinterceptor.New( + rt, + routeRefCounter, + allowedIPsRefCounter, + statusRecorder, + dnsServer, + peerConns, + ) } dns := nbdns.NewServiceViaMemory(wgInterface) - return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) + return dynamic.NewRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouterInteval, + statusRecorder, + wgInterface, + fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), + ) } return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 7c5b55a23..4cfa8f042 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -24,6 +24,7 @@ type DnsInterceptor struct { dnsServer nbdns.Server currentPeerKey string interceptedIPs map[string]netip.Prefix + peerConns map[string]*peer.Conn } func New( @@ -32,6 +33,7 @@ func New( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, statusRecorder *peer.Status, dnsServer nbdns.Server, + peerConns map[string]*peer.Conn, ) *DnsInterceptor { return &DnsInterceptor{ route: rt, @@ -40,6 +42,7 @@ func New( statusRecorder: statusRecorder, dnsServer: dnsServer, interceptedIPs: make(map[string]netip.Prefix), + peerConns: peerConns, } } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index addedcb54..3afdd509e 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -12,6 +12,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" @@ -34,9 +35,11 @@ import ( // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector + GetClientRoutes() route.HAMap + GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error @@ -61,7 +64,10 @@ type DefaultManager struct { allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration stateManager *statemanager.Manager - dnsServer dns.Server + // clientRoutes is the most recent list of clientRoutes received from the Management Service + clientRoutes route.HAMap + dnsServer dns.Server + peerConns map[string]*peer.Conn } func NewManager( @@ -74,6 +80,7 @@ func NewManager( initialRoutes []*route.Route, stateManager *statemanager.Manager, dnsServer dns.Server, + peerConns map[string]*peer.Conn, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier := notifier.NewNotifier() @@ -92,6 +99,7 @@ func NewManager( notifier: notifier, stateManager: stateManager, dnsServer: dnsServer, + peerConns: peerConns, } dm.routeRefCounter = refcounter.New( @@ -120,7 +128,7 @@ func NewManager( ) if runtime.GOOS == "android" { - cr := dm.clientRoutes(initialRoutes) + cr := dm.initialClientRoutes(initialRoutes) dm.notifier.SetInitialClientRoutes(cr) } return dm @@ -211,33 +219,40 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } m.ctx = nil + + m.mux.Lock() + defer m.mux.Unlock() + m.clientRoutes = nil } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") - return nil, nil, m.ctx.Err() + return nil default: - m.mux.Lock() - defer m.mux.Unlock() - - newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) - - filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) - m.updateClientNetworks(updateSerial, filteredClientRoutes) - m.notifier.OnNewRoutes(filteredClientRoutes) - - if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return nil, nil, fmt.Errorf("update routes: %w", err) - } - } - - return newServerRoutesMap, newClientRoutesIDMap, nil } + + m.mux.Lock() + defer m.mux.Unlock() + + newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) + + filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + m.updateClientNetworks(updateSerial, filteredClientRoutes) + m.notifier.OnNewRoutes(filteredClientRoutes) + + if m.serverRouter != nil { + err := m.serverRouter.updateRoutes(newServerRoutesMap) + if err != nil { + return err + } + } + + m.clientRoutes = newClientRoutesIDMap + + return nil } // SetRouteChangeListener set RouteListener for route change Notifier @@ -255,9 +270,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector { return m.routeSelector } -// GetClientRoutes returns the client routes -func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { - return m.clientNetworks +// GetClientRoutes returns most recent list of clientRoutes received from the Management Service +func (m *DefaultManager) GetClientRoutes() route.HAMap { + m.mux.Lock() + defer m.mux.Unlock() + + return maps.Clone(m.clientRoutes) +} + +// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only +func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + m.mux.Lock() + defer m.mux.Unlock() + + routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes)) + for id, v := range m.clientRoutes { + routes[id.NetID()] = v + } + return routes } // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones @@ -286,6 +316,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, + m.peerConns, ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() @@ -315,16 +346,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher( - m.ctx, - m.dnsRouteInterval, - m.wgInterface, - m.statusRecorder, - routes[0], - m.routeRefCounter, - m.allowedIPsRefCounter, - m.dnsServer, - ) + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, nil) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() } @@ -367,7 +389,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] return newServerRoutesMap, newClientRoutesIDMap } -func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { +func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route { _, crMap := m.classifyRoutes(initialRoutes) rs := make([]*route.Route, 0, len(crMap)) for _, routes := range crMap { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 07dac21b8..ce6e8b539 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil) _, _, err = routeManager.Init() @@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) require.NoError(t, err, "should update routes with init routes") } - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 556a62351..0219b17c8 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -2,7 +2,6 @@ package routemanager import ( "context" - "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" @@ -15,10 +14,12 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) - TriggerSelectionFunc func(haMap route.HAMap) - GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func(manager *statemanager.Manager) + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + TriggerSelectionFunc func(haMap route.HAMap) + GetRouteSelectorFunc func() *routeselector.RouteSelector + GetClientRoutesFunc func() route.HAMap + GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route + StopFunc func(manager *statemanager.Manager) } func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { @@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) } - return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") + return nil } func (m *MockManager) TriggerSelection(networks route.HAMap) { @@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { return nil } +// GetClientRoutes mock implementation of GetClientRoutes from Manager interface +func (m *MockManager) GetClientRoutes() route.HAMap { + if m.GetClientRoutesFunc != nil { + return m.GetClientRoutesFunc() + } + return nil +} + +// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface +func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + if m.GetClientRoutesWithNetIDFunc != nil { + return m.GetClientRoutesWithNetIDFunc() + } + return nil +} + // Start mock implementation of Start from Manager interface func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { } diff --git a/client/server/route.go b/client/server/route.go index d70e0dca3..f1312a5f3 100644 --- a/client/server/route.go +++ b/client/server/route.go @@ -34,7 +34,7 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() + routesMap := engine.GetRouteManager().GetClientRoutesWithNetID() routeSelector := engine.GetRouteManager().GetRouteSelector() var routes []*selectRoute @@ -116,11 +116,12 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + netIdRoutes := maps.Keys(engine.GetRouteManager().GetClientRoutesWithNetID()) + if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(engine.GetRouteManager().GetClientRoutes()) return &proto.SelectRoutesResponse{}, nil } @@ -145,11 +146,12 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + netIdRoutes := maps.Keys(engine.GetRouteManager().GetClientRoutesWithNetID()) + if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(engine.GetRouteManager().GetClientRoutes()) return &proto.SelectRoutesResponse{}, nil }