diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f0277319c..7fc50505f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -33,7 +33,8 @@ type Server interface { Initialize() error Stop() DnsIP() string - UpdateDNSServer(serial uint64, update nbdns.Config) error + DnsPort() int + UpdateDNSServer(update nbdns.Config, hasDNSRoute bool) error OnUpdatedHostDNSServer(strings []string) SearchDomains() []string ProbeAvailability() @@ -51,7 +52,6 @@ type DefaultServer struct { localResolver *localResolver wgInterface WGIface hostManager hostManager - updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig @@ -183,6 +183,11 @@ func (s *DefaultServer) DnsIP() string { return s.service.RuntimeIP() } +func (s *DefaultServer) DnsPort() int { + // Todo: review what will be if the service is not running yet + return s.service.RuntimePort() +} + // Stop stops the server func (s *DefaultServer) Stop() { s.mux.Lock() @@ -215,16 +220,12 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { } // UpdateDNSServer processes an update received from the management service -func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { +func (s *DefaultServer) UpdateDNSServer(update nbdns.Config, hasDNSRoute bool) error { select { case <-s.ctx.Done(): log.Infof("not updating DNS server as context is closed") return s.ctx.Err() default: - if serial < s.updateSerial { - return fmt.Errorf("not applying dns update, error: "+ - "network update is %d behind the last applied update", s.updateSerial-serial) - } s.mux.Lock() defer s.mux.Unlock() @@ -244,17 +245,14 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro if s.previousConfigHash == hash { log.Debugf("not applying the dns configuration update as there is nothing new") - s.updateSerial = serial return nil } - if err := s.applyConfiguration(update); err != nil { + if err := s.applyConfiguration(update, hasDNSRoute); err != nil { return fmt.Errorf("apply configuration: %w", err) } - s.updateSerial = serial s.previousConfigHash = hash - return nil } } @@ -288,15 +286,18 @@ func (s *DefaultServer) ProbeAvailability() { wg.Wait() } -func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { +func (s *DefaultServer) applyConfiguration(update nbdns.Config, hasDNSRoute bool) error { // is the service should be Disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records - if update.ServiceEnable { + if update.ServiceEnable || hasDNSRoute { _ = s.service.Listen() } else if !s.permanent { s.service.Stop() } + // trace the dns configuration update + log.Infof("---- dns server listen address: %s:%d", s.service.RuntimeIP(), s.service.RuntimePort()) + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { return fmt.Errorf("not applying dns update, error: %v", err) diff --git a/client/internal/engine.go b/client/internal/engine.go index 63caec02a..d0123d7a2 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -44,7 +44,6 @@ import ( "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" mgm "github.com/netbirdio/netbird/management/client" - "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" auth "github.com/netbirdio/netbird/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/relay/client" @@ -802,14 +801,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - protoRoutes := networkMap.GetRoutes() - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} + // todo keep the state because of the serial or eliminate the serial usage from dns and route mgr + networkMapMgr := networkMapHandler{ + DNSServer: e.dnsServer, + RouteManager: e.routeManager, } - - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { - log.Errorf("failed to update clientRoutes, err: %v", err) + if err := networkMapMgr.update(serial, networkMap); err != nil { + log.Warnf("failed to update apply network map: %v", err) + // todo: consider to return here with error } e.clientRoutesMu.Lock() @@ -858,16 +857,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } - protoDNSConfig := networkMap.GetDNSConfig() - if protoDNSConfig == nil { - protoDNSConfig = &mgmProto.DNSConfig{} - } - - err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) - if err != nil { - log.Errorf("failed to update dns server, err: %v", err) - } - e.networkSerial = serial // Test received (upstream) servers for availability right away instead of upon usage. @@ -877,76 +866,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } -func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { - routes := make([]*route.Route, 0) - for _, protoRoute := range protoRoutes { - var prefix netip.Prefix - if len(protoRoute.Domains) == 0 { - var err error - if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil { - log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err) - continue - } - } - convertedRoute := &route.Route{ - ID: route.ID(protoRoute.ID), - Network: prefix, - Domains: domain.FromPunycodeList(protoRoute.Domains), - NetID: route.NetID(protoRoute.NetID), - NetworkType: route.NetworkType(protoRoute.NetworkType), - Peer: protoRoute.Peer, - Metric: int(protoRoute.Metric), - Masquerade: protoRoute.Masquerade, - KeepRoute: protoRoute.KeepRoute, - } - routes = append(routes, convertedRoute) - } - return routes -} - -func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { - dnsUpdate := nbdns.Config{ - ServiceEnable: protoDNSConfig.GetServiceEnable(), - CustomZones: make([]nbdns.CustomZone, 0), - NameServerGroups: make([]*nbdns.NameServerGroup, 0), - } - - for _, zone := range protoDNSConfig.GetCustomZones() { - dnsZone := nbdns.CustomZone{ - Domain: zone.GetDomain(), - } - for _, record := range zone.Records { - dnsRecord := nbdns.SimpleRecord{ - Name: record.GetName(), - Type: int(record.GetType()), - Class: record.GetClass(), - TTL: int(record.GetTTL()), - RData: record.GetRData(), - } - dnsZone.Records = append(dnsZone.Records, dnsRecord) - } - dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone) - } - - for _, nsGroup := range protoDNSConfig.GetNameServerGroups() { - dnsNSGroup := &nbdns.NameServerGroup{ - Primary: nsGroup.GetPrimary(), - Domains: nsGroup.GetDomains(), - SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(), - } - for _, ns := range nsGroup.GetNameServers() { - dnsNS := nbdns.NameServer{ - IP: netip.MustParseAddr(ns.GetIP()), - NSType: nbdns.NameServerType(ns.GetNSType()), - Port: int(ns.GetPort()), - } - dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS) - } - dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup) - } - return dnsUpdate -} - func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { replacement := make([]peer.State, len(offlinePeers)) for i, offlinePeer := range offlinePeers { @@ -1235,7 +1154,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { if err != nil { return nil, nil, err } - routes := toRoutes(netMap.GetRoutes()) + _, routes := toRoutes(netMap.GetRoutes()) dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, &dnsCfg, nil } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index b58c1f7e9..1f9ab1dbd 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -509,7 +509,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { expectedSerial uint64 }{ { - name: "Routes Config Should Be Passed To Manager", + name: "Routes Config Should Be Passed To networkMapHandler", networkMap: &mgmtProto.NetworkMap{ Serial: 1, PeerConfig: nil, diff --git a/client/internal/netmap_handler.go b/client/internal/netmap_handler.go new file mode 100644 index 000000000..af61b9081 --- /dev/null +++ b/client/internal/netmap_handler.go @@ -0,0 +1,171 @@ +package internal + +import ( + "fmt" + "net" + "net/netip" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + mgmProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/route" + + log "github.com/sirupsen/logrus" +) + +type networkMapHandler struct { + DNSServer dns.Server + RouteManager routemanager.Manager + Firewall firewall.Manager + + updateSerial uint64 + dnsRules []firewall.Rule +} + +func (h *networkMapHandler) update(serial uint64, networkMap *mgmProto.NetworkMap) error { + if serial < h.updateSerial { + return fmt.Errorf("not applying dns update, error: "+ + "network update is %d behind the last applied update", h.updateSerial-serial) + } + + hasDNSRoute, routes := toRoutes(networkMap.GetRoutes()) + DNSConfig := toDNSConfig(networkMap.GetDNSConfig()) + + if err := h.DNSServer.UpdateDNSServer(DNSConfig, hasDNSRoute); err != nil { + log.Errorf("failed to update dns server, err: %v", err) + return err + } + h.updateSerial = serial + + // todo: consider to eliminate the serial management from the client.go + _, err := h.RouteManager.UpdateRoutes(serial, routes) + if err != nil { + log.Errorf("failed to update routes, err: %v", err) + return err + } + + if hasDNSRoute { + if err := h.allowDNSFirewall(); err != nil { + return err + } + } else { + if err := h.dropDNSFirewall(); err != nil { + return err + } + } + return nil +} + +func (h *networkMapHandler) allowDNSFirewall() error { + dport := &firewall.Port{ + IsRange: false, + Values: []int{h.DNSServer.DnsPort()}, + } + dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + h.dnsRules = dnsRules + return nil +} + +func (h *networkMapHandler) dropDNSFirewall() error { + if len(h.dnsRules) == 0 { + return nil + } + + for _, rule := range h.dnsRules { + if err := h.Firewall.DeletePeerRule(rule); err != nil { + log.Errorf("failed to delete DNS router rules, err: %v", err) + return err + } + } + + h.dnsRules = nil + return nil +} + +func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { + if protoDNSConfig == nil { + protoDNSConfig = &mgmProto.DNSConfig{} + } + + dnsUpdate := nbdns.Config{ + ServiceEnable: protoDNSConfig.GetServiceEnable(), + CustomZones: make([]nbdns.CustomZone, 0), + NameServerGroups: make([]*nbdns.NameServerGroup, 0), + } + + for _, zone := range protoDNSConfig.GetCustomZones() { + dnsZone := nbdns.CustomZone{ + Domain: zone.GetDomain(), + } + for _, record := range zone.Records { + dnsRecord := nbdns.SimpleRecord{ + Name: record.GetName(), + Type: int(record.GetType()), + Class: record.GetClass(), + TTL: int(record.GetTTL()), + RData: record.GetRData(), + } + dnsZone.Records = append(dnsZone.Records, dnsRecord) + } + dnsUpdate.CustomZones = append(dnsUpdate.CustomZones, dnsZone) + } + + for _, nsGroup := range protoDNSConfig.GetNameServerGroups() { + dnsNSGroup := &nbdns.NameServerGroup{ + Primary: nsGroup.GetPrimary(), + Domains: nsGroup.GetDomains(), + SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(), + } + for _, ns := range nsGroup.GetNameServers() { + dnsNS := nbdns.NameServer{ + IP: netip.MustParseAddr(ns.GetIP()), + NSType: nbdns.NameServerType(ns.GetNSType()), + Port: int(ns.GetPort()), + } + dnsNSGroup.NameServers = append(dnsNSGroup.NameServers, dnsNS) + } + dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup) + } + return dnsUpdate +} + +func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + var hasDNSRoute bool + routes := make([]*route.Route, 0) + for _, protoRoute := range protoRoutes { + var prefix netip.Prefix + if len(protoRoute.Domains) == 0 { + var err error + if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil { + log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err) + continue + } + } + + hasDNSRoute = true + + convertedRoute := &route.Route{ + ID: route.ID(protoRoute.ID), + Network: prefix, + Domains: domain.FromPunycodeList(protoRoute.Domains), + NetID: route.NetID(protoRoute.NetID), + NetworkType: route.NetworkType(protoRoute.NetworkType), + Peer: protoRoute.Peer, + Metric: int(protoRoute.Metric), + Masquerade: protoRoute.Masquerade, + KeepRoute: protoRoute.KeepRoute, + } + routes = append(routes, convertedRoute) + } + return hasDNSRoute, routes +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 8bf3a91b0..edb352f63 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -33,7 +33,7 @@ import ( // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector SetRouteChangeListener(listener listener.NetworkChangeListener) @@ -60,6 +60,7 @@ type DefaultManager struct { allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration stateManager *statemanager.Manager + dnsRule []firewall.Rule // todo: remove rule in stop action } func NewManager( @@ -210,11 +211,11 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (route.HAMap, error) { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") - return nil, nil, m.ctx.Err() + return nil, m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() @@ -226,13 +227,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.notifier.OnNewRoutes(filteredClientRoutes) if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return nil, nil, fmt.Errorf("update routes: %w", err) + if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil { + return nil, fmt.Errorf("update routes: %w", err) } } - return newServerRoutesMap, newClientRoutesIDMap, nil + return newClientRoutesIDMap, nil } } diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 07dac21b8..063885740 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) require.NoError(t, err, "should update routes with init routes") } - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected