diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 1f4ddb67c..3e338267f 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string { continue } - listOfDomains = append(listOfDomains, dConf.Domain) + listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, ".")) } return listOfDomains } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 3286daabf..6baf9ed95 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority } // First remove any existing handler with same pattern (case-insensitive) and priority - for i := len(c.handlers) - 1; i >= 0; i-- { - if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { - c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) - break - } - } + c.removeEntry(origPattern, priority) // Check if handler implements SubdomainMatcher interface matchSubdomains := false @@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) { pattern = dns.Fqdn(pattern) + c.removeEntry(pattern, priority) +} + +func (c *HandlerChain) removeEntry(pattern string, priority int) { // Find and remove handlers matching both original pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) - return + break } } } -// HasHandlers returns true if there are any handlers remaining for the given pattern -func (c *HandlerChain) HasHandlers(pattern string) bool { - c.mu.RLock() - defer c.mu.RUnlock() - - pattern = strings.ToLower(dns.Fqdn(pattern)) - for _, entry := range c.handlers { - if strings.EqualFold(entry.Pattern, pattern) { - return true - } - } - return false -} - func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { return diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 94aa987af..4c910a95f 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { for _, handler := range handlers { handler.AssertExpectations(t) } - - // Verify handler exists check - for priority, shouldExist := range tt.expectedCalls { - if shouldExist { - assert.True(t, chain.HasHandlers(tt.ops[0].pattern), - "Handler chain should have handlers for pattern after removing priority %d", priority) - } - } }) } } @@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { r := new(dns.Msg) r.SetQuestion(testQuery, dns.TypeA) + // Keep track of mocks for the final assertion in Step 4 + mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler} + // Add handlers in mixed order chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) - // Test 1: Initial state with all three handlers - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Test 1: Initial state + w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} // Highest priority handler (routeHandler) should be called routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() + matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet + defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet - chain.ServeDNS(w, r) + chain.ServeDNS(w1, r) routeHandler.AssertExpectations(t) + routeHandler.ExpectedCalls = nil + routeHandler.Calls = nil + matchHandler.ExpectedCalls = nil + matchHandler.Calls = nil + defaultHandler.ExpectedCalls = nil + defaultHandler.Calls = nil + // Test 2: Remove highest priority handler chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) - assert.True(t, chain.HasHandlers(testDomain)) - w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} // Now middle priority handler (matchHandler) should be called matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() + defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet - chain.ServeDNS(w, r) + chain.ServeDNS(w2, r) matchHandler.AssertExpectations(t) + matchHandler.ExpectedCalls = nil + matchHandler.Calls = nil + defaultHandler.ExpectedCalls = nil + defaultHandler.Calls = nil + // Test 3: Remove middle priority handler chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) - assert.True(t, chain.HasHandlers(testDomain)) - w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} // Now lowest priority handler (defaultHandler) should be called defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() - chain.ServeDNS(w, r) + chain.ServeDNS(w3, r) defaultHandler.AssertExpectations(t) + defaultHandler.ExpectedCalls = nil + defaultHandler.Calls = nil + // Test 4: Remove last handler chain.RemoveHandler(testDomain, nbdns.PriorityDefault) - assert.False(t, chain.HasHandlers(testDomain)) + w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain + + for _, m := range mocks { + m.AssertNumberOfCalls(t, "ServeDNS", 0) + } } func TestHandlerChain_CaseSensitivity(t *testing.T) { @@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { }) } } + +func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { + tests := []struct { + name string + addPattern string + removePattern string + queryPattern string + shouldBeRemoved bool + description string + }{ + { + name: "exact same pattern", + addPattern: "example.com.", + removePattern: "example.com.", + queryPattern: "example.com.", + shouldBeRemoved: true, + description: "Adding and removing with identical patterns", + }, + { + name: "case difference", + addPattern: "Example.Com.", + removePattern: "EXAMPLE.COM.", + queryPattern: "example.com.", + shouldBeRemoved: true, + description: "Adding with mixed case, removing with uppercase", + }, + { + name: "reversed case difference", + addPattern: "EXAMPLE.ORG.", + removePattern: "example.org.", + queryPattern: "example.org.", + shouldBeRemoved: true, + description: "Adding with uppercase, removing with lowercase", + }, + { + name: "add wildcard, remove wildcard", + addPattern: "*.example.com.", + removePattern: "*.example.com.", + queryPattern: "sub.example.com.", + shouldBeRemoved: true, + description: "Adding and removing with identical wildcard patterns", + }, + { + name: "add wildcard, remove transformed pattern", + addPattern: "*.example.net.", + removePattern: "example.net.", + queryPattern: "sub.example.net.", + shouldBeRemoved: false, + description: "Adding with wildcard, removing with non-wildcard pattern", + }, + { + name: "add transformed pattern, remove wildcard", + addPattern: "example.io.", + removePattern: "*.example.io.", + queryPattern: "example.io.", + shouldBeRemoved: false, + description: "Adding with non-wildcard pattern, removing with wildcard pattern", + }, + { + name: "trailing dot difference", + addPattern: "example.dev", + removePattern: "example.dev.", + queryPattern: "example.dev.", + shouldBeRemoved: true, + description: "Adding without trailing dot, removing with trailing dot", + }, + { + name: "reversed trailing dot difference", + addPattern: "example.app.", + removePattern: "example.app", + queryPattern: "example.app.", + shouldBeRemoved: true, + description: "Adding with trailing dot, removing without trailing dot", + }, + { + name: "mixed case and wildcard", + addPattern: "*.Example.Site.", + removePattern: "*.EXAMPLE.SITE.", + queryPattern: "sub.example.site.", + shouldBeRemoved: true, + description: "Adding mixed case wildcard, removing uppercase wildcard", + }, + { + name: "root zone", + addPattern: ".", + removePattern: ".", + queryPattern: "random.domain.", + shouldBeRemoved: true, + description: "Adding and removing root zone", + }, + { + name: "wrong domain", + addPattern: "example.com.", + removePattern: "different.com.", + queryPattern: "example.com.", + shouldBeRemoved: false, + description: "Adding one domain, trying to remove a different domain", + }, + { + name: "subdomain mismatch", + addPattern: "sub.example.com.", + removePattern: "example.com.", + queryPattern: "sub.example.com.", + shouldBeRemoved: false, + description: "Adding subdomain, trying to remove parent domain", + }, + { + name: "parent domain mismatch", + addPattern: "example.com.", + removePattern: "sub.example.com.", + queryPattern: "example.com.", + shouldBeRemoved: false, + description: "Adding parent domain, trying to remove subdomain", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + + handler := &nbdns.MockHandler{} + r := new(dns.Msg) + r.SetQuestion(tt.queryPattern, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // First verify no handler is called before adding any + chain.ServeDNS(w, r) + handler.AssertNotCalled(t, "ServeDNS") + + // Add handler + chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault) + + // Verify handler is called after adding + handler.On("ServeDNS", mock.Anything, r).Once() + chain.ServeDNS(w, r) + handler.AssertExpectations(t) + + // Reset mock for the next test + handler.ExpectedCalls = nil + + // Remove handler + chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault) + + // Set up expectations based on whether removal should succeed + if !tt.shouldBeRemoved { + handler.On("ServeDNS", mock.Anything, r).Once() + } + + // Test if handler is still called after removal attempt + chain.ServeDNS(w, r) + + if tt.shouldBeRemoved { + handler.AssertNotCalled(t, "ServeDNS", + "Handler should not be called after successful removal with pattern %q", + tt.removePattern) + } else { + handler.AssertExpectations(t) + handler.ExpectedCalls = nil + } + }) + } +} diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index 25e9ff7e5..dbf0f2cfc 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -5,6 +5,8 @@ import ( "net/netip" "strings" + "github.com/miekg/dns" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -12,8 +14,8 @@ import ( var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") const ( - ipv4ReverseZone = ".in-addr.arpa" - ipv6ReverseZone = ".ip6.arpa" + ipv4ReverseZone = ".in-addr.arpa." + ipv6ReverseZone = ".ip6.arpa." ) type hostManager interface { @@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD for _, domain := range nsConfig.Domains { config.Domains = append(config.Domains, DomainConfig{ - Domain: strings.TrimSuffix(domain, "."), + Domain: strings.ToLower(dns.Fqdn(domain)), MatchOnly: !nsConfig.SearchDomainsEnabled, }) } @@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD for _, customZone := range dnsConfig.CustomZones { matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) config.Domains = append(config.Domains, DomainConfig{ - Domain: strings.TrimSuffix(customZone.Domain, "."), + Domain: strings.ToLower(dns.Fqdn(customZone.Domain)), MatchOnly: matchOnly, }) } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index f727f68b5..a445bc6c4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * continue } if dConf.MatchOnly { - matchDomains = append(matchDomains, dConf.Domain) + matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, ".")) continue } - searchDomains = append(searchDomains, dConf.Domain) + searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, ".")) } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 285904f71..cfba29501 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,9 +17,12 @@ import ( var ( userenv = syscall.NewLazyDLL("userenv.dll") + dnsapi = syscall.NewLazyDLL("dnsapi.dll") // https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx") + + dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache") ) const ( @@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager continue } if !dConf.MatchOnly { - searchDomains = append(searchDomains, dConf.Domain) + searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, ".")) } - matchDomains = append(matchDomains, "."+dConf.Domain) + matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } if len(matchDomains) != 0 { @@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return fmt.Errorf("update search domains: %w", err) } + if err := r.flushDNSCache(); err != nil { + log.Errorf("failed to flush DNS cache: %v", err) + } + return nil } @@ -184,6 +191,26 @@ func (r *registryConfigurator) string() string { return "registry" } +func (r *registryConfigurator) flushDNSCache() error { + // dnsFlushResolverCacheFn.Call() may panic if the func is not found + defer func() { + if rec := recover(); rec != nil { + log.Errorf("Recovered from panic in flushDNSCache: %v", rec) + } + }() + + ret, _, err := dnsFlushResolverCacheFn.Call() + if ret == 0 { + if err != nil && !errors.Is(err, syscall.Errno(0)) { + return fmt.Errorf("DnsFlushResolverCache failed: %w", err) + } + return fmt.Errorf("DnsFlushResolverCache failed") + } + + log.Info("flushed DNS cache") + return nil +} + func (r *registryConfigurator) updateSearchDomains(domains []string) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -236,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error { return fmt.Errorf("remove interface registry key: %w", err) } + if err := r.flushDNSCache(); err != nil { + log.Errorf("failed to flush DNS cache: %v", err) + } + return nil } diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 7e36ea5df..c5dd6e23f 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -6,6 +6,7 @@ import ( "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" ) // MockServer is the mock instance of a dns server @@ -13,17 +14,17 @@ type MockServer struct { InitializeFunc func() error StopFunc func() UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func([]string, dns.Handler, int) - DeregisterHandlerFunc func([]string, int) + RegisterHandlerFunc func(domain.List, dns.Handler, int) + DeregisterHandlerFunc func(domain.List, int) } -func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { +func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { if m.RegisterHandlerFunc != nil { m.RegisterHandlerFunc(domains, handler, priority) } } -func (m *MockServer) DeregisterHandler(domains []string, priority int) { +func (m *MockServer) DeregisterHandler(domains domain.List, priority int) { if m.DeregisterHandlerFunc != nil { m.DeregisterHandlerFunc(domains, priority) } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 10b4e6a6e..caae63a24 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -13,7 +13,6 @@ import ( "github.com/godbus/dbus/v5" "github.com/hashicorp/go-version" - "github.com/miekg/dns" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st continue } if dConf.MatchOnly { - matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain)) + matchDomains = append(matchDomains, "~."+dConf.Domain) continue } - searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain)) + searchDomains = append(searchDomains, dConf.Domain) } newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index bc87012f2..74ab6717f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -6,11 +6,13 @@ import ( "fmt" "net/netip" "runtime" + "strings" "sync" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/listener" @@ -18,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" cProto "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" ) // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes @@ -32,8 +35,8 @@ type IosDnsManager interface { // Server is a dns server interface type Server interface { - RegisterHandler(domains []string, handler dns.Handler, priority int) - DeregisterHandler(domains []string, priority int) + RegisterHandler(domains domain.List, handler dns.Handler, priority int) + DeregisterHandler(domains domain.List, priority int) Initialize() error Stop() DnsIP() string @@ -65,6 +68,7 @@ type DefaultServer struct { previousConfigHash uint64 currentConfig HostDNSConfig handlerChain *HandlerChain + extraDomains map[domain.Domain]int // permanent related properties permanent bool @@ -164,13 +168,15 @@ func newDefaultServer( stateManager *statemanager.Manager, disableSys bool, ) *DefaultServer { + handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ ctx: ctx, ctxCancel: stop, disableSys: disableSys, service: dnsService, - handlerChain: NewHandlerChain(), + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), dnsMuxMap: make(registeredHandlerMap), localResolver: &localResolver{ registeredMap: make(registrationMap), @@ -181,14 +187,26 @@ func newDefaultServer( hostsDNSHolder: newHostsDNSHolder(), } + // register with root zone, handler chain takes care of the routing + dnsService.RegisterMux(".", handlerChain) + return defaultServer } -func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { +// RegisterHandler registers a handler for the given domains with the given priority. +// Any previously registered handler for the same domain and priority will be replaced. +func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { s.mux.Lock() defer s.mux.Unlock() - s.registerHandler(domains, handler, priority) + s.registerHandler(domains.ToPunycodeList(), handler, priority) + + // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain + for _, domain := range domains { + // convert to zone with simple ref counter + s.extraDomains[toZone(domain)]++ + } + s.applyHostConfig() } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { @@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p continue } s.handlerChain.AddHandler(domain, handler, priority) - s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) } } -func (s *DefaultServer) DeregisterHandler(domains []string, priority int) { +// DeregisterHandler deregisters the handler for the given domains with the given priority. +func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { s.mux.Lock() defer s.mux.Unlock() - s.deregisterHandler(domains, priority) + s.deregisterHandler(domains.ToPunycodeList(), priority) + for _, domain := range domains { + zone := toZone(domain) + s.extraDomains[zone]-- + if s.extraDomains[zone] <= 0 { + delete(s.extraDomains, zone) + } + } + s.applyHostConfig() } func (s *DefaultServer) deregisterHandler(domains []string, priority int) { @@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) { } s.handlerChain.RemoveHandler(domain, priority) - - // Only deregister from service if no handlers remain - if !s.handlerChain.HasHandlers(domain) { - s.service.DeregisterMux(nbdns.NormalizeZone(domain)) - } } } @@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() { } s.service.Stop() + + maps.Clear(s.extraDomains) } // OnUpdatedHostDNSServer update the DNS servers addresses for root zones @@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records if update.ServiceEnable { - _ = s.service.Listen() + if err := s.service.Listen(); err != nil { + log.Errorf("failed to start DNS service: %v", err) + } } else if !s.permanent { s.service.Stop() } @@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) - hostUpdate := s.currentConfig if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") - hostUpdate.RouteAll = false + s.currentConfig.RouteAll = false } - if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { - log.Error(err) - s.handleErrNoGroupaAll(err) - } + s.applyHostConfig() go func() { // persist dns state right away @@ -441,6 +462,38 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { return nil } +func (s *DefaultServer) applyHostConfig() { + if s.hostManager == nil { + return + } + + config := s.currentConfig + + existingDomains := make(map[string]struct{}) + for _, d := range config.Domains { + existingDomains[d.Domain] = struct{}{} + } + + // add extra domains only if they're not already in the config + for domain := range s.extraDomains { + domainStr := domain.PunycodeString() + + if _, exists := existingDomains[domainStr]; !exists { + config.Domains = append(config.Domains, DomainConfig{ + Domain: domainStr, + MatchOnly: true, + }) + } + } + + log.Debugf("extra match domains: %v", s.extraDomains) + + if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { + log.Errorf("failed to apply DNS host manager update: %v", err) + s.handleErrNoGroupaAll(err) + } +} + func (s *DefaultServer) handleErrNoGroupaAll(err error) { if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) { return @@ -690,10 +743,7 @@ func (s *DefaultServer) upstreamCallbacks( } } - if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { - s.handleErrNoGroupaAll(err) - l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) - } + s.applyHostConfig() go func() { if err := s.stateManager.PersistState(s.ctx); err != nil { @@ -728,12 +778,7 @@ func (s *DefaultServer) upstreamCallbacks( s.registerHandler([]string{nbdns.RootZone}, handler, priority) } - if s.hostManager != nil { - if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { - s.handleErrNoGroupaAll(err) - l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") - } - } + s.applyHostConfig() s.updateNSState(nsGroup, nil, true) } @@ -836,3 +881,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain return result } + +func toZone(d domain.Domain) domain.Domain { + return domain.Domain( + nbdns.NormalizeZone( + dns.Fqdn( + strings.ToLower(d.PunycodeString()), + ), + ), + ) +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 8a15c430b..ed69b0e93 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/management/domain" ) var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() @@ -38,7 +39,7 @@ type mocWGIface struct { } func (w *mocWGIface) Name() string { - panic("implement me") + return "utun2301" } func (w *mocWGIface) Address() wgaddr.Address { @@ -1448,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) { }) } } + +func TestExtraDomains(t *testing.T) { + tests := []struct { + name string + initialConfig nbdns.Config + registerDomains []domain.List + deregisterDomains []domain.List + finalConfig nbdns.Config + expectedDomains []string + expectedMatchOnly []string + applyHostConfigCall int + }{ + { + name: "Register domains before config update", + registerDomains: []domain.List{ + {"extra1.example.com", "extra2.example.com"}, + }, + initialConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + }, + }, + expectedDomains: []string{ + "config.example.com.", + "extra1.example.com.", + "extra2.example.com.", + }, + expectedMatchOnly: []string{ + "extra1.example.com.", + "extra2.example.com.", + }, + applyHostConfigCall: 2, + }, + { + name: "Register domains after config update", + initialConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + }, + }, + registerDomains: []domain.List{ + {"extra1.example.com", "extra2.example.com"}, + }, + expectedDomains: []string{ + "config.example.com.", + "extra1.example.com.", + "extra2.example.com.", + }, + expectedMatchOnly: []string{ + "extra1.example.com.", + "extra2.example.com.", + }, + applyHostConfigCall: 2, + }, + { + name: "Register overlapping domains", + initialConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + {Domain: "overlap.example.com"}, + }, + }, + registerDomains: []domain.List{ + {"extra.example.com", "overlap.example.com"}, + }, + expectedDomains: []string{ + "config.example.com.", + "overlap.example.com.", + "extra.example.com.", + }, + expectedMatchOnly: []string{ + "extra.example.com.", + }, + applyHostConfigCall: 2, + }, + { + name: "Register and deregister domains", + initialConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + }, + }, + registerDomains: []domain.List{ + {"extra1.example.com", "extra2.example.com"}, + {"extra3.example.com", "extra4.example.com"}, + }, + deregisterDomains: []domain.List{ + {"extra1.example.com", "extra3.example.com"}, + }, + expectedDomains: []string{ + "config.example.com.", + "extra2.example.com.", + "extra4.example.com.", + }, + expectedMatchOnly: []string{ + "extra2.example.com.", + "extra4.example.com.", + }, + applyHostConfigCall: 4, + }, + { + name: "Register domains with ref counter", + initialConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + }, + }, + registerDomains: []domain.List{ + {"extra.example.com", "duplicate.example.com"}, + {"other.example.com", "duplicate.example.com"}, + }, + deregisterDomains: []domain.List{ + {"duplicate.example.com"}, + }, + expectedDomains: []string{ + "config.example.com.", + "extra.example.com.", + "other.example.com.", + "duplicate.example.com.", + }, + expectedMatchOnly: []string{ + "extra.example.com.", + "other.example.com.", + "duplicate.example.com.", + }, + applyHostConfigCall: 4, + }, + { + name: "Config update with new domains after registration", + initialConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + }, + }, + registerDomains: []domain.List{ + {"extra.example.com", "duplicate.example.com"}, + }, + finalConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + {Domain: "newconfig.example.com"}, + }, + }, + expectedDomains: []string{ + "config.example.com.", + "newconfig.example.com.", + "extra.example.com.", + "duplicate.example.com.", + }, + expectedMatchOnly: []string{ + "extra.example.com.", + "duplicate.example.com.", + }, + applyHostConfigCall: 3, + }, + { + name: "Deregister domain that is part of customZones", + initialConfig: nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + {Domain: "protected.example.com"}, + }, + }, + registerDomains: []domain.List{ + {"extra.example.com", "protected.example.com"}, + }, + deregisterDomains: []domain.List{ + {"protected.example.com"}, + }, + expectedDomains: []string{ + "extra.example.com.", + "config.example.com.", + "protected.example.com.", + }, + expectedMatchOnly: []string{ + "extra.example.com.", + }, + applyHostConfigCall: 3, + }, + { + name: "Register domain that is part of nameserver group", + initialConfig: nbdns.Config{ + ServiceEnable: true, + NameServerGroups: []*nbdns.NameServerGroup{ + { + Domains: []string{"ns.example.com", "overlap.ns.example.com"}, + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + }, + }, + }, + registerDomains: []domain.List{ + {"extra.example.com", "overlap.ns.example.com"}, + }, + expectedDomains: []string{ + "ns.example.com.", + "overlap.ns.example.com.", + "extra.example.com.", + }, + expectedMatchOnly: []string{ + "ns.example.com.", + "overlap.ns.example.com.", + "extra.example.com.", + }, + applyHostConfigCall: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedConfigs []HostDNSConfig + mockHostConfig := &mockHostConfigurator{ + applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error { + capturedConfigs = append(capturedConfigs, config) + return nil + }, + restoreHostDNSFunc: func() error { + return nil + }, + supportCustomPortFunc: func() bool { + return true + }, + stringFunc: func() string { + return "mock" + }, + } + + mockSvc := &mockService{} + + server := &DefaultServer{ + ctx: context.Background(), + handlerChain: NewHandlerChain(), + wgInterface: &mocWGIface{}, + hostManager: mockHostConfig, + localResolver: &localResolver{}, + service: mockSvc, + statusRecorder: peer.NewRecorder("test"), + extraDomains: make(map[domain.Domain]int), + } + + // Apply initial configuration + if tt.initialConfig.ServiceEnable { + err := server.applyConfiguration(tt.initialConfig) + assert.NoError(t, err) + } + + // Register domains + for _, domains := range tt.registerDomains { + server.RegisterHandler(domains, &MockHandler{}, PriorityDefault) + } + + // Deregister domains if specified + for _, domains := range tt.deregisterDomains { + server.DeregisterHandler(domains, PriorityDefault) + } + + // Apply final configuration if specified + if tt.finalConfig.ServiceEnable { + err := server.applyConfiguration(tt.finalConfig) + assert.NoError(t, err) + } + + // Verify number of calls + assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs), + "Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs)) + + // Get the last applied config + lastConfig := capturedConfigs[len(capturedConfigs)-1] + + // Check all expected domains are present + domainMap := make(map[string]bool) + matchOnlyMap := make(map[string]bool) + + for _, d := range lastConfig.Domains { + domainMap[d.Domain] = true + if d.MatchOnly { + matchOnlyMap[d.Domain] = true + } + } + + // Verify expected domains + for _, d := range tt.expectedDomains { + assert.True(t, domainMap[d], "Expected domain %s not found in final config", d) + } + + // Verify match-only domains + for _, d := range tt.expectedMatchOnly { + assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d) + } + + // Verify no unexpected domains + assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config") + assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config") + }) + } +} + +func TestExtraDomainsRefCounting(t *testing.T) { + mockHostConfig := &mockHostConfigurator{ + applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error { + return nil + }, + restoreHostDNSFunc: func() error { + return nil + }, + supportCustomPortFunc: func() bool { + return true + }, + stringFunc: func() string { + return "mock" + }, + } + + mockSvc := &mockService{} + + server := &DefaultServer{ + ctx: context.Background(), + handlerChain: NewHandlerChain(), + hostManager: mockHostConfig, + localResolver: &localResolver{}, + service: mockSvc, + statusRecorder: peer.NewRecorder("test"), + extraDomains: make(map[domain.Domain]int), + } + + // Register domains from different handlers with same domain + server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) + server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) + + // Verify refcount is 2 + zoneKey := toZone("shared.example.com") + assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") + + // Deregister one handler + server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) + + // Verify refcount is 1 + assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") + + // Deregister the other handler + server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute) + + // Verify domain is removed + _, exists := server.extraDomains[zoneKey] + assert.False(t, exists, "Domain should be removed after deregistering all handlers") +} + +func TestUpdateConfigWithExistingExtraDomains(t *testing.T) { + var capturedConfig HostDNSConfig + mockHostConfig := &mockHostConfigurator{ + applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error { + capturedConfig = config + return nil + }, + restoreHostDNSFunc: func() error { + return nil + }, + supportCustomPortFunc: func() bool { + return true + }, + stringFunc: func() string { + return "mock" + }, + } + + mockSvc := &mockService{} + + server := &DefaultServer{ + ctx: context.Background(), + handlerChain: NewHandlerChain(), + hostManager: mockHostConfig, + localResolver: &localResolver{}, + service: mockSvc, + statusRecorder: peer.NewRecorder("test"), + extraDomains: make(map[domain.Domain]int), + } + + server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault) + + initialConfig := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + }, + } + err := server.applyConfiguration(initialConfig) + assert.NoError(t, err) + + var domains []string + for _, d := range capturedConfig.Domains { + domains = append(domains, d.Domain) + } + assert.Contains(t, domains, "config.example.com.") + assert.Contains(t, domains, "extra.example.com.") + + // Now apply a new configuration with overlapping domain + updatedConfig := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + {Domain: "extra.example.com"}, + }, + } + err = server.applyConfiguration(updatedConfig) + assert.NoError(t, err) + + // Verify both domains are in config, but no duplicates + domains = []string{} + matchOnlyCount := 0 + for _, d := range capturedConfig.Domains { + domains = append(domains, d.Domain) + if d.MatchOnly { + matchOnlyCount++ + } + } + + assert.Contains(t, domains, "config.example.com.") + assert.Contains(t, domains, "extra.example.com.") + assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates") + + // Extra domain should no longer be marked as match-only when in config + matchOnlyDomain := "" + for _, d := range capturedConfig.Domains { + if d.Domain == "extra.example.com." && d.MatchOnly { + matchOnlyDomain = d.Domain + break + } + } + assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config") +} + +func TestDomainCaseHandling(t *testing.T) { + var capturedConfig HostDNSConfig + mockHostConfig := &mockHostConfigurator{ + applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error { + capturedConfig = config + return nil + }, + restoreHostDNSFunc: func() error { + return nil + }, + supportCustomPortFunc: func() bool { + return true + }, + stringFunc: func() string { + return "mock" + }, + } + + mockSvc := &mockService{} + server := &DefaultServer{ + ctx: context.Background(), + handlerChain: NewHandlerChain(), + hostManager: mockHostConfig, + localResolver: &localResolver{}, + service: mockSvc, + statusRecorder: peer.NewRecorder("test"), + extraDomains: make(map[domain.Domain]int), + } + + server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) + server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) + + assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") + + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + {Domain: "config.example.com"}, + }, + } + err := server.applyConfiguration(config) + assert.NoError(t, err) + + var domains []string + for _, d := range capturedConfig.Domains { + domains = append(domains, d.Domain) + } + assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") + assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") +} diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index db1418ef1..53c5c58a0 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -11,7 +11,6 @@ import ( "time" "github.com/godbus/dbus/v5" - "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" @@ -111,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana continue } domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ - Domain: dns.Fqdn(dConf.Domain), + Domain: dConf.Domain, MatchOnly: dConf.MatchOnly, }) @@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana if err != nil { log.Error(err) } + + if err := s.flushDNSCache(); err != nil { + log.Errorf("failed to flush DNS cache: %v", err) + } + return nil } @@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD if err != nil { return fmt.Errorf("setting domains configuration failed with error: %w", err) } - return s.flushCaches() + + return nil } func (s *systemdDbusConfigurator) restoreHostDNS() error { @@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("unable to revert link configuration, got error: %w", err) } - return s.flushCaches() + if err := s.flushDNSCache(); err != nil { + log.Errorf("failed to flush DNS cache: %v", err) + } + + return nil } -func (s *systemdDbusConfigurator) flushCaches() error { +func (s *systemdDbusConfigurator) flushDNSCache() error { obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) if err != nil { return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index a22689cf9..53fa20f62 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -23,9 +23,10 @@ import ( ) const ( + UpstreamTimeout = 15 * time.Second + failsTillDeact = int32(5) reactivatePeriod = 30 * time.Second - upstreamTimeout = 15 * time.Second probeTimeout = 2 * time.Second ) @@ -66,7 +67,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d ctx: ctx, cancel: cancel, domain: domain, - upstreamTimeout: upstreamTimeout, + upstreamTimeout: UpstreamTimeout, reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, statusRecorder: statusRecorder, diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index a9e46ca02..06ffcba11 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin // exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - timeout := upstreamTimeout + timeout := UpstreamTimeout if deadline, ok := ctx.Deadline(); ok { timeout = time.Until(deadline) } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 7d3301e14..c73079b92 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) } - timeout := upstreamTimeout + timeout := UpstreamTimeout if deadline, ok := ctx.Deadline(); ok { timeout = time.Until(deadline) } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index c5adc0858..5dbcc9f79 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { name: "Should Resolve A Record", inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"}, - timeout: upstreamTimeout, + timeout: UpstreamTimeout, expectedAnswer: "1.1.1.1", }, { @@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"}, cancelCTX: true, - timeout: upstreamTimeout, + timeout: UpstreamTimeout, responseShouldBeNil: true, }, } @@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { r: new(dns.Msg), rtt: time.Millisecond, }, - upstreamTimeout: upstreamTimeout, + upstreamTimeout: UpstreamTimeout, reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 2e6e4fede..42d740d90 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -6,7 +6,6 @@ import ( "net/netip" "strings" "sync" - "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -60,7 +59,7 @@ func (d *DnsInterceptor) String() string { } func (d *DnsInterceptor) AddRoute(context.Context) error { - d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute) + d.dnsServer.RegisterHandler(d.route.Domains, d, nbdns.PriorityDNSRoute) return nil } @@ -89,7 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error { clear(d.interceptedDomains) d.mu.Unlock() - d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) + d.dnsServer.DeregisterHandler(d.route.Domains, nbdns.PriorityDNSRoute) return nberrors.FormatErrorOrNil(merr) } @@ -142,21 +141,24 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { log.Tracef("received DNS request for domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + // pass if non A/AAAA query + if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { + d.continueToNextHandler(w, r, "non A/AAAA query") + return + } + d.mu.RLock() peerKey := d.currentPeerKey d.mu.RUnlock() if peerKey == "" { - log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name) - - d.continueToNextHandler(w, r, "no current peer key") + d.writeDNSError(w, r, "no current peer key") return } upstreamIP, err := d.getUpstreamIP(peerKey) if err != nil { - log.Errorf("failed to get upstream IP: %v", err) - d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err)) + d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err)) return } @@ -165,34 +167,43 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.SetEdns0(4096, false) r.MsgHdr.AuthenticatedData = true } - client := &dns.Client{ - Timeout: 5 * time.Second, + Timeout: nbdns.UpstreamTimeout, Net: "udp", } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) reply, _, err := client.ExchangeContext(context.Background(), r, upstream) - - var answer []dns.RR - if reply != nil { - answer = reply.Answer - } - log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) - if err != nil { - log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) + log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { log.Errorf("failed writing DNS response: %v", err) } return } + var answer []dns.RR + if reply != nil { + answer = reply.Answer + } + + log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) + reply.Id = r.Id if err := d.writeMsg(w, reply); err != nil { log.Errorf("failed writing DNS response: %v", err) } } +func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) { + log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) + + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeServerFailure) + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS error response: %v", err) + } +} + // continueToNextHandler signals the handler chain to try the next handler func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 079134701..47511d4af 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -235,7 +235,7 @@ func (r *Route) resolve(results chan resolveResult) { ips, err := r.getIPsFromResolver(domain) if err != nil { log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err) - ips, err = net.LookupIP(string(domain)) + ips, err = net.LookupIP(domain.PunycodeString()) if err != nil { results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} return diff --git a/client/internal/routemanager/dynamic/route_generic.go b/client/internal/routemanager/dynamic/route_generic.go index cf3d913a4..a618a2392 100644 --- a/client/internal/routemanager/dynamic/route_generic.go +++ b/client/internal/routemanager/dynamic/route_generic.go @@ -9,5 +9,5 @@ import ( ) func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { - return net.LookupIP(string(domain)) + return net.LookupIP(domain.PunycodeString()) } diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go index 67138222f..145d395e6 100644 --- a/client/internal/routemanager/dynamic/route_ios.go +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -23,7 +23,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { } msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA) + msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA) startTime := time.Now() diff --git a/client/server/network.go b/client/server/network.go index 1b7962b78..e0b01f763 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro // Convert to proto format for domain, ips := range domainMap { - pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ + pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{ Ips: ips, } } diff --git a/management/domain/domain.go b/management/domain/domain.go index e7e6b050a..2e089b01f 100644 --- a/management/domain/domain.go +++ b/management/domain/domain.go @@ -24,6 +24,11 @@ func (d Domain) SafeString() string { return str } +// PunycodeString returns the punycode representation of the Domain. +func (d Domain) PunycodeString() string { + return string(d) +} + // FromString creates a Domain from a string, converting it to punycode. func FromString(s string) (Domain, error) { ascii, err := idna.ToASCII(s)