diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 673f410e2..3286daabf 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -12,7 +12,7 @@ import ( const ( PriorityDNSRoute = 100 PriorityMatchDomain = 50 - PriorityDefault = 0 + PriorityDefault = 1 ) type SubdomainMatcher interface { @@ -26,7 +26,6 @@ type HandlerEntry struct { Pattern string OrigPattern string IsWildcard bool - StopHandler handlerWithStop MatchSubdomains bool } @@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string { } // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority -func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { +func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) { c.mu.Lock() defer c.mu.Unlock() @@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority // First remove any existing handler with same pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { - if c.handlers[i].StopHandler != nil { - c.handlers[i].StopHandler.stop() - } c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) break } @@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority Pattern: pattern, OrigPattern: origPattern, IsWildcard: isWildcard, - StopHandler: stopHandler, MatchSubdomains: matchSubdomains, } @@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) { for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { - if entry.StopHandler != nil { - entry.StopHandler.stop() - } c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) return } @@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if log.IsLevelEnabled(log.TraceLevel) { log.Tracef("current handlers (%d):", len(handlers)) for _, h := range handlers { - log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", - h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority) } } @@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } if !matched { - log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", - qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) + log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false", + qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority) continue } - log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", - qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) + log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d", + qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority) chainWriter := &ResponseWriterChain{ ResponseWriter: w, diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index d04bfbbb3..8c66446ee 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { dnsRouteHandler := &nbdns.MockHandler{} // Setup handlers with different priorities - chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) - chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) - chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) + chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) // Create test request r := new(dns.Msg) @@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { pattern = "*." + tt.handlerDomain[2:] } - chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) + chain.AddHandler(pattern, handler, nbdns.PriorityDefault) r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) @@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() } - chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) + chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority) } // Create and execute request @@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { handler3 := &nbdns.MockHandler{} // Add handlers in priority order - chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) - chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) - chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) + chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) + chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) // Create test request r := new(dns.Msg) @@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { if op.action == "add" { handler := &nbdns.MockHandler{} handlers[op.priority] = handler - chain.AddHandler(op.pattern, handler, op.priority, nil) + chain.AddHandler(op.pattern, handler, op.priority) } else { chain.RemoveHandler(op.pattern, op.priority) } @@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { r.SetQuestion(testQuery, dns.TypeA) // Add handlers in mixed order - chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) - chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) - chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) + chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) + chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) // Test 1: Initial state with all three handlers w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} @@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { handler = mockHandler } - chain.AddHandler(pattern, handler, h.priority, nil) + chain.AddHandler(pattern, handler, h.priority) } // Execute request diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 9a78d4d50..1fe88f750 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -29,10 +29,15 @@ func (d *localResolver) String() string { return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) } +// ID returns the unique handler ID +func (d *localResolver) id() handlerID { + return "local-resolver" +} + // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) > 0 { - log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) } replyMessage := &dns.Msg{} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 1fe913fd9..f714f9857 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,7 +5,6 @@ import ( "fmt" "net/netip" "runtime" - "strings" "sync" "github.com/miekg/dns" @@ -42,7 +41,12 @@ type Server interface { ProbeAvailability() } -type registeredHandlerMap map[string]handlerWithStop +type handlerID string + +type nsGroupsByDomain struct { + domain string + groups []*nbdns.NameServerGroup +} // DefaultServer dns server object type DefaultServer struct { @@ -52,7 +56,6 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap - handlerPriorities map[string]int localResolver *localResolver wgInterface WGIface hostManager hostManager @@ -77,14 +80,17 @@ type handlerWithStop interface { dns.Handler stop() probeAvailability() + id() handlerID } -type muxUpdate struct { +type handlerWrapper struct { domain string handler handlerWithStop priority int } +type registeredHandlerMap map[handlerID]handlerWrapper + // NewDefaultServer returns a new dns server func NewDefaultServer( ctx context.Context, @@ -158,13 +164,12 @@ func newDefaultServer( ) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: NewHandlerChain(), - dnsMuxMap: make(registeredHandlerMap), - handlerPriorities: make(map[string]int), + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: NewHandlerChain(), + dnsMuxMap: make(registeredHandlerMap), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p log.Warn("skipping empty domain") continue } - s.handlerChain.AddHandler(domain, handler, priority, nil) - s.handlerPriorities[domain] = priority + s.handlerChain.AddHandler(domain, handler, priority) s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) } } @@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) { log.Debugf("deregistering handler %v with priority %d", domains, priority) for _, domain := range domains { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.handlerChain.RemoveHandler(domain, priority) // Only deregister from service if no handlers remain if !s.handlerChain.HasHandlers(domain) { - if domain == "" { - log.Warn("skipping empty domain") - continue - } s.service.DeregisterMux(nbdns.NormalizeZone(domain)) } } @@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone + func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { s.hostsDNSHolder.set(hostsDnsList) - _, ok := s.dnsMuxMap[nbdns.RootZone] - if ok { + // Check if there's any root handler + var hasRootHandler bool + for _, handler := range s.dnsMuxMap { + if handler.domain == nbdns.RootZone { + hasRootHandler = true + break + } + } + + if hasRootHandler { log.Debugf("on new host DNS config but skip to apply it") return } + log.Debugf("update host DNS settings: %+v", hostsDnsList) s.addHostRootZone() } @@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() { go func(mux handlerWithStop) { defer wg.Done() mux.probeAvailability() - }(mux) + }(mux.handler) } wg.Wait() } @@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { return nil } -func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { - var muxUpdates []muxUpdate +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) { + var muxUpdates []handlerWrapper localRecords := make(map[string]nbdns.SimpleRecord, 0) for _, customZone := range customZones { @@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) return nil, nil, fmt.Errorf("received an empty list of records") } - muxUpdates = append(muxUpdates, muxUpdate{ + muxUpdates = append(muxUpdates, handlerWrapper{ domain: customZone.Domain, handler: s.localResolver, priority: PriorityMatchDomain, @@ -446,15 +461,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) return muxUpdates, localRecords, nil } -func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { +func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) { + var muxUpdates []handlerWrapper - var muxUpdates []muxUpdate for _, nsGroup := range nameServerGroups { if len(nsGroup.NameServers) == 0 { log.Warn("received a nameserver group with empty nameserver list") continue } + if !nsGroup.Primary && len(nsGroup.Domains) == 0 { + return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") + } + + for _, domain := range nsGroup.Domains { + if domain == "" { + return nil, fmt.Errorf("received a nameserver group with an empty domain element") + } + } + } + + groupedNS := groupNSGroupsByDomain(nameServerGroups) + + for _, domainGroup := range groupedNS { + basePriority := PriorityMatchDomain + if domainGroup.domain == nbdns.RootZone { + basePriority = PriorityDefault + } + + updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority) + if err != nil { + return nil, err + } + muxUpdates = append(muxUpdates, updates...) + } + + return muxUpdates, nil +} + +func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) { + var muxUpdates []handlerWrapper + + for i, nsGroup := range domainGroup.groups { + // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts + priority := basePriority - i + + // Check if we're about to overlap with the next priority tier + if basePriority == PriorityMatchDomain && priority <= PriorityDefault { + log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", + domainGroup.domain, PriorityMatchDomain-PriorityDefault) + break + } + + log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) handler, err := newUpstreamResolver( s.ctx, s.wgInterface.Name(), @@ -462,10 +521,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam s.wgInterface.Address().Network, s.statusRecorder, s.hostsDNSHolder, + domainGroup.domain, ) if err != nil { - return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) + return nil, fmt.Errorf("create upstream resolver: %v", err) } + for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", @@ -489,78 +550,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam // after some period defined by upstream it tries to reactivate self by calling this hook // everything we need here is just to re-apply current configuration because it already // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) + handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority) - if nsGroup.Primary { - muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, - priority: PriorityDefault, - }) - continue - } - - if len(nsGroup.Domains) == 0 { - handler.stop() - return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") - } - - for _, domain := range nsGroup.Domains { - if domain == "" { - handler.stop() - return nil, fmt.Errorf("received a nameserver group with an empty domain element") - } - muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, - priority: PriorityMatchDomain, - }) - } + muxUpdates = append(muxUpdates, handlerWrapper{ + domain: domainGroup.domain, + handler: handler, + priority: priority, + }) } return muxUpdates, nil } -func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { - muxUpdateMap := make(registeredHandlerMap) - handlersByPriority := make(map[string]int) - - var isContainRootUpdate bool - - // First register new handlers - for _, update := range muxUpdates { - s.registerHandler([]string{update.domain}, update.handler, update.priority) - muxUpdateMap[update.domain] = update.handler - handlersByPriority[update.domain] = update.priority - - if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { - existingHandler.stop() - } - - if update.domain == nbdns.RootZone { - isContainRootUpdate = true - } +func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { + // this will introduce a short period of time when the server is not able to handle DNS requests + for _, existing := range s.dnsMuxMap { + s.deregisterHandler([]string{existing.domain}, existing.priority) + existing.handler.stop() } - // Then deregister old handlers not in the update - for key, existingHandler := range s.dnsMuxMap { - _, found := muxUpdateMap[key] - if !found { - if !isContainRootUpdate && key == nbdns.RootZone { + muxUpdateMap := make(registeredHandlerMap) + var containsRootUpdate bool + + for _, update := range muxUpdates { + if update.domain == nbdns.RootZone { + containsRootUpdate = true + } + s.registerHandler([]string{update.domain}, update.handler, update.priority) + muxUpdateMap[update.handler.id()] = update + } + + // If there's no root update and we had a root handler, restore it + if !containsRootUpdate { + for _, existing := range s.dnsMuxMap { + if existing.domain == nbdns.RootZone { s.addHostRootZone() - existingHandler.stop() - } else { - existingHandler.stop() - // Deregister with the priority that was used to register - if oldPriority, ok := s.handlerPriorities[key]; ok { - s.deregisterHandler([]string{key}, oldPriority) - } + break } } } s.dnsMuxMap = muxUpdateMap - s.handlerPriorities = handlersByPriority } func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { @@ -593,6 +623,7 @@ func getNSHostPort(ns nbdns.NameServer) string { func (s *DefaultServer) upstreamCallbacks( nsGroup *nbdns.NameServerGroup, handler dns.Handler, + priority int, ) (deactivate func(error), reactivate func()) { var removeIndex map[string]int deactivate = func(err error) { @@ -609,13 +640,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) + s.deregisterHandler([]string{nbdns.RootZone}, priority) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) + s.deregisterHandler([]string{item.Domain}, priority) removeIndex[item.Domain] = i } } @@ -635,8 +666,8 @@ func (s *DefaultServer) upstreamCallbacks( } s.updateNSState(nsGroup, err, false) - } + reactivate = func() { s.mux.Lock() defer s.mux.Unlock() @@ -646,7 +677,7 @@ func (s *DefaultServer) upstreamCallbacks( continue } s.currentConfig.Domains[i].Disabled = false - s.registerHandler([]string{domain}, handler, PriorityMatchDomain) + s.registerHandler([]string{domain}, handler, priority) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -654,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) + s.registerHandler([]string{nbdns.RootZone}, handler, priority) } if s.hostManager != nil { @@ -676,6 +707,7 @@ func (s *DefaultServer) addHostRootZone() { s.wgInterface.Address().Network, s.statusRecorder, s.hostsDNSHolder, + nbdns.RootZone, ) if err != nil { log.Errorf("unable to create a new upstream resolver, error: %v", err) @@ -732,5 +764,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { for _, ns := range nsGroup.NameServers { servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) } - return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ",")) + return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) +} + +// groupNSGroupsByDomain groups nameserver groups by their match domains +func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain { + domainMap := make(map[string][]*nbdns.NameServerGroup) + + for _, group := range nsGroups { + if group.Primary { + domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group) + continue + } + + for _, domain := range group.Domains { + if domain == "" { + continue + } + domainMap[domain] = append(domainMap[domain], group) + } + } + + var result []nsGroupsByDomain + for domain, groups := range domainMap { + result = append(result, nsGroupsByDomain{ + domain: domain, + groups: groups, + }) + } + + return result } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 14ff1bb71..db49f96a2 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -13,6 +13,7 @@ import ( "github.com/golang/mock/gomock" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -88,6 +89,18 @@ func init() { formatter.SetTextFormatter(log.StandardLogger()) } +func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { + var srvs []string + for _, srv := range servers { + srvs = append(srvs, getNSHostPort(srv)) + } + return &upstreamResolverBase{ + domain: domain, + upstreamServers: srvs, + cancel: func() {}, + } +} + func TestUpdateDNSServer(t *testing.T) { nameServers := []nbdns.NameServer{ { @@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler}, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedUpstreamMap: registeredHandlerMap{ + generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + domain: "netbird.io", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + dummyHandler.id(): handlerWrapper{ + domain: "netbird.cloud", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + generateDummyHandler(".", nameServers).id(): handlerWrapper{ + domain: nbdns.RootZone, + handler: dummyHandler, + priority: PriorityDefault, + }, + }, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { - name: "New Config Should Succeed", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler}, - initSerial: 0, - inputSerial: 1, + name: "New Config Should Succeed", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registeredHandlerMap{ + generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + domain: buildRecordKey(zoneRecords[0].Name, 1, 1), + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler}, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedUpstreamMap: registeredHandlerMap{ + generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + domain: "netbird.io", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + "local-resolver": handlerWrapper{ + domain: "netbird.cloud", + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "Smaller Config Serial Should Be Skipped", @@ -242,9 +288,15 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Empty Config Should Succeed and Clean Maps", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, + name: "Empty Config Should Succeed and Clean Maps", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registeredHandlerMap{ + generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + domain: zoneRecords[0].Name, + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: true}, @@ -252,9 +304,15 @@ func TestUpdateDNSServer(t *testing.T) { expectedLocalMap: make(registrationMap), }, { - name: "Disabled Service Should clean map", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, + name: "Disabled Service Should clean map", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registeredHandlerMap{ + generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + domain: zoneRecords[0].Name, + handler: dummyHandler, + priority: PriorityMatchDomain, + }, + }, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: false}, @@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { } }() - dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}} + dnsServer.dnsMuxMap = registeredHandlerMap{ + "id1": handlerWrapper{ + domain: zoneRecords[0].Name, + handler: &localResolver{}, + priority: PriorityMatchDomain, + }, + } dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} dnsServer.updateSerial = 0 @@ -562,9 +626,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { localResolver: &localResolver{ registeredMap: make(registrationMap), }, - handlerChain: NewHandlerChain(), - handlerPriorities: make(map[string]int), - hostManager: hostManager, + handlerChain: NewHandlerChain(), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { NameServers: []nbdns.NameServer{ {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, }, - }, nil) + }, nil, 0) deactivate(nil) expected := "domain0,domain2" @@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { Subdomains: true, } - chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) - chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) + chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) + chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) testCases := []struct { name string @@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { }) } } + +type mockHandler struct { + Id string +} + +func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} +func (m *mockHandler) stop() {} +func (m *mockHandler) probeAvailability() {} +func (m *mockHandler) id() handlerID { return handlerID(m.Id) } + +type mockService struct{} + +func (m *mockService) Listen() error { return nil } +func (m *mockService) Stop() {} +func (m *mockService) RuntimeIP() string { return "127.0.0.1" } +func (m *mockService) RuntimePort() int { return 53 } +func (m *mockService) RegisterMux(string, dns.Handler) {} +func (m *mockService) DeregisterMux(string) {} + +func TestDefaultServer_UpdateMux(t *testing.T) { + baseMatchHandlers := registeredHandlerMap{ + "upstream-group1": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + "upstream-group2": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + } + + baseRootHandlers := registeredHandlerMap{ + "upstream-root1": { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + "upstream-root2": { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + } + + baseMixedHandlers := registeredHandlerMap{ + "upstream-group1": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + "upstream-group2": { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + "upstream-other": { + domain: "other.com", + handler: &mockHandler{ + Id: "upstream-other", + }, + priority: PriorityMatchDomain, + }, + } + + tests := []struct { + name string + initialHandlers registeredHandlerMap + updates []handlerWrapper + expectedHandlers map[string]string // map[handlerID]domain + description string + }{ + { + name: "Remove group1 from update", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Only group2 remains + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group2": "example.com", + }, + description: "When group1 is not included in the update, it should be removed while group2 remains", + }, + { + name: "Remove group2 from update", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Only group1 remains + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + }, + description: "When group2 is not included in the update, it should be removed while group1 remains", + }, + { + name: "Add group3 in first position", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Add group3 with highest priority + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group3", + }, + priority: PriorityMatchDomain + 1, + }, + // Keep existing groups with their original priorities + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-group2": "example.com", + "upstream-group3": "example.com", + }, + description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups", + }, + { + name: "Add group3 in last position", + initialHandlers: baseMatchHandlers, + updates: []handlerWrapper{ + // Keep existing groups with their original priorities + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + // Add group3 with lowest priority + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group3", + }, + priority: PriorityMatchDomain - 2, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-group2": "example.com", + "upstream-group3": "example.com", + }, + description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups", + }, + // Root zone tests + { + name: "Remove root1 from update", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root2": ".", + }, + description: "When root1 is not included in the update, it should be removed while root2 remains", + }, + { + name: "Remove root2 from update", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root1": ".", + }, + description: "When root2 is not included in the update, it should be removed while root1 remains", + }, + { + name: "Add root3 in first position", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root3", + }, + priority: PriorityDefault + 1, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root1": ".", + "upstream-root2": ".", + "upstream-root3": ".", + }, + description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers", + }, + { + name: "Add root3 in last position", + initialHandlers: baseRootHandlers, + updates: []handlerWrapper{ + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root1", + }, + priority: PriorityDefault, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root2", + }, + priority: PriorityDefault - 1, + }, + { + domain: ".", + handler: &mockHandler{ + Id: "upstream-root3", + }, + priority: PriorityDefault - 2, + }, + }, + expectedHandlers: map[string]string{ + "upstream-root1": ".", + "upstream-root2": ".", + "upstream-root3": ".", + }, + description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers", + }, + // Mixed domain tests + { + name: "Update with mixed domains - remove one of duplicate domain", + initialHandlers: baseMixedHandlers, + updates: []handlerWrapper{ + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "other.com", + handler: &mockHandler{ + Id: "upstream-other", + }, + priority: PriorityMatchDomain, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-other": "other.com", + }, + description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains", + }, + { + name: "Update with mixed domains - add new domain", + initialHandlers: baseMixedHandlers, + updates: []handlerWrapper{ + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group1", + }, + priority: PriorityMatchDomain, + }, + { + domain: "example.com", + handler: &mockHandler{ + Id: "upstream-group2", + }, + priority: PriorityMatchDomain - 1, + }, + { + domain: "other.com", + handler: &mockHandler{ + Id: "upstream-other", + }, + priority: PriorityMatchDomain, + }, + { + domain: "new.com", + handler: &mockHandler{ + Id: "upstream-new", + }, + priority: PriorityMatchDomain, + }, + }, + expectedHandlers: map[string]string{ + "upstream-group1": "example.com", + "upstream-group2": "example.com", + "upstream-other": "other.com", + "upstream-new": "new.com", + }, + description: "When updating mixed domains, should maintain existing duplicates and add new domain", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &DefaultServer{ + dnsMuxMap: tt.initialHandlers, + handlerChain: NewHandlerChain(), + service: &mockService{}, + } + + // Perform the update + server.updateMux(tt.updates) + + // Verify the results + assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap), + "Number of handlers after update doesn't match expected") + + // Check each expected handler + for id, expectedDomain := range tt.expectedHandlers { + handler, exists := server.dnsMuxMap[handlerID(id)] + assert.True(t, exists, "Expected handler %s not found", id) + if exists { + assert.Equal(t, expectedDomain, handler.domain, + "Domain mismatch for handler %s", id) + } + } + + // Verify no unexpected handlers exist + for handlerID := range server.dnsMuxMap { + _, expected := tt.expectedHandlers[string(handlerID)] + assert.True(t, expected, "Unexpected handler found: %s", handlerID) + } + + // Verify the handlerChain state and order + previousPriority := 0 + for _, chainEntry := range server.handlerChain.handlers { + // Verify priority order + if previousPriority > 0 { + assert.True(t, chainEntry.Priority <= previousPriority, + "Handlers in chain not properly ordered by priority") + } + previousPriority = chainEntry.Priority + + // Verify handler exists in mux + foundInMux := false + for _, muxEntry := range server.dnsMuxMap { + if chainEntry.Handler == muxEntry.handler && + chainEntry.Priority == muxEntry.priority && + chainEntry.Pattern == dns.Fqdn(muxEntry.domain) { + foundInMux = true + break + } + } + assert.True(t, foundInMux, + "Handler in chain not found in dnsMuxMap") + } + }) + } +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index f0aa12b65..4c69a173d 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -2,9 +2,13 @@ package dns import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "net" + "slices" + "strings" "sync" "sync/atomic" "time" @@ -40,6 +44,7 @@ type upstreamResolverBase struct { cancel context.CancelFunc upstreamClient upstreamClient upstreamServers []string + domain string disabled bool failsCount atomic.Int32 successCount atomic.Int32 @@ -53,12 +58,13 @@ type upstreamResolverBase struct { statusRecorder *peer.Status } -func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase { +func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ ctx: ctx, cancel: cancel, + domain: domain, upstreamTimeout: upstreamTimeout, reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, @@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string { return fmt.Sprintf("upstream %v", u.upstreamServers) } +// ID returns the unique handler ID +func (u *upstreamResolverBase) id() handlerID { + servers := slices.Clone(u.upstreamServers) + slices.Sort(servers) + + hash := sha256.New() + hash.Write([]byte(u.domain + ":")) + hash.Write([]byte(strings.Join(servers, ","))) + return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) +} + func (u *upstreamResolverBase) MatchSubdomains() bool { return true } @@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { u.checkUpstreamFails(err) }() - log.WithField("question", r.Question[0]).Trace("received an upstream question") + log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) // set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records if r.Extra == nil { r.SetEdns0(4096, false) @@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { select { case <-u.ctx.Done(): + log.Tracef("%s has been stopped", u) return default: } @@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if err != nil { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - log.WithError(err).WithField("upstream", upstream). - Warn("got an error while connecting to upstream") + log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) continue } - u.failsCount.Add(1) - log.WithError(err).WithField("upstream", upstream). - Error("got other error while querying the upstream") - return + log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) + continue } - if rm == nil { - log.WithError(err).WithField("upstream", upstream). - Warn("no response from upstream") - return - } - // those checks need to be independent of each other due to memory address issues - if !rm.Response { - log.WithError(err).WithField("upstream", upstream). - Warn("no response from upstream") - return + if rm == nil || !rm.Response { + log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + continue } u.successCount.Add(1) - log.Tracef("took %s to query the upstream %s", t, upstream) + log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) - err = w.WriteMsg(rm) - if err != nil { - log.WithError(err).Error("got an error while writing the upstream resolver response") + if err = w.WriteMsg(rm); err != nil { + log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) } // count the fails only if they happen sequentially u.failsCount.Store(0) return } u.failsCount.Add(1) - log.Error("all queries to the upstream nameservers failed with timeout") + log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) + + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeServerFailure) + if err := w.WriteMsg(m); err != nil { + log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) + } } // checkUpstreamFails counts fails and disables or enables upstream resolving diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 36ea05e44..a9e46ca02 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -27,8 +27,9 @@ func newUpstreamResolver( _ *net.IPNet, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, + domain string, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) c := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, hostsDNSHolder: hostsDNSHolder, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index a29350f8c..51acbf7a6 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -23,8 +23,9 @@ func newUpstreamResolver( _ *net.IPNet, statusRecorder *peer.Status, _ *hostsDNSHolder, + domain string, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) nonIOS := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 60ed79d87..7d3301e14 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -30,8 +30,9 @@ func newUpstreamResolver( net *net.IPNet, statusRecorder *peer.Status, _ *hostsDNSHolder, + domain string, ) (*upstreamResolverIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index c1251dcc1..c5adc0858 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { timeout time.Duration cancelCTX bool expectedAnswer string + acceptNXDomain bool }{ { name: "Should Resolve A Record", @@ -36,11 +37,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { expectedAnswer: "1.1.1.1", }, { - name: "Should Not Resolve If Can't Connect To Both Servers", - inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), - InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"}, - timeout: 200 * time.Millisecond, - responseShouldBeNil: true, + name: "Should Not Resolve If Can't Connect To Both Servers", + inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), + InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"}, + timeout: 200 * time.Millisecond, + acceptNXDomain: true, }, { name: "Should Not Resolve If Parent Context Is Canceled", @@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { responseShouldBeNil: true, }, } - // should resolve if first upstream times out - // should not write when both fails - // should not resolve if parent context is canceled for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil) + resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".") resolver.upstreamServers = testCase.InputServers resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { @@ -84,16 +82,22 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { t.Fatalf("should write a response message") } - foundAnswer := false - for _, answer := range responseMSG.Answer { - if strings.Contains(answer.String(), testCase.expectedAnswer) { - foundAnswer = true - break - } + if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError { + return } - if !foundAnswer { - t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer) + if testCase.expectedAnswer != "" { + foundAnswer := false + for _, answer := range responseMSG.Answer { + if strings.Contains(answer.String(), testCase.expectedAnswer) { + foundAnswer = true + break + } + } + + if !foundAnswer { + t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer) + } } }) } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 0df2a2e81..311ddbd7f 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -721,7 +721,9 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { func (d *Status) GetDNSStates() []NSGroupState { d.mux.Lock() defer d.mux.Unlock() - return d.nsGroupStates + + // shallow copy is good enough, as slices fields are currently not updated + return slices.Clone(d.nsGroupStates) } func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {