From 75c1be69cfba17293d085841d58a939e7936c6c7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 17 Jun 2025 14:02:30 +0200 Subject: [PATCH] [client] Prioritze the local resolver in the dns handler chain (#3965) --- client/internal/dns/handler_chain.go | 7 +- client/internal/dns/handler_chain_test.go | 62 ++++---- client/internal/dns/server.go | 14 +- client/internal/dns/server_test.go | 170 ++++++++++++++++++---- 4 files changed, 183 insertions(+), 70 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 22caaa761..7e7e7cc2d 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,9 +11,10 @@ import ( ) const ( - PriorityDNSRoute = 100 - PriorityMatchDomain = 50 - PriorityDefault = 1 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 ) type SubdomainMatcher interface { diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 5f03e0758..72c0004d5 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { // Setup handlers with different priorities chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) - chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) // Create test request @@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, - {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, }, queryDomain: "test.example.com.", @@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, - {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "test.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, }, queryDomain: "sub.test.example.com.", @@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { // Add handlers in priority order chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) - chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) // Create test request @@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, {"remove", "example.com.", nbdns.PriorityDNSRoute}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: false, - nbdns.PriorityMatchDomain: true, + nbdns.PriorityDNSRoute: false, + nbdns.PriorityUpstream: true, }, }, { @@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, - {"remove", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, + {"remove", "example.com.", nbdns.PriorityUpstream}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: true, - nbdns.PriorityMatchDomain: false, + nbdns.PriorityDNSRoute: true, + nbdns.PriorityUpstream: false, }, }, { @@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, {"add", "example.com.", nbdns.PriorityDefault}, {"remove", "example.com.", nbdns.PriorityDNSRoute}, - {"remove", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityUpstream}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: false, - nbdns.PriorityMatchDomain: false, - nbdns.PriorityDefault: true, + nbdns.PriorityDNSRoute: false, + nbdns.PriorityUpstream: false, + nbdns.PriorityDefault: true, }, }, } @@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Add handlers in mixed order chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) - chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream) // Test 1: Initial state w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} @@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { defaultHandler.Calls = nil // Test 3: Remove middle priority handler - chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) + chain.RemoveHandler(testDomain, nbdns.PriorityUpstream) w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now lowest priority handler (defaultHandler) should be called @@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { shouldMatch bool }{ {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, - {"example.com.", nbdns.PriorityMatchDomain, false, false}, + {"example.com.", nbdns.PriorityUpstream, false, false}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true}, }, query: "example.com.", @@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", @@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, true}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", @@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, - {"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, true}, + {"add", "test.sub.example.com.", nbdns.PriorityUpstream, false}, + {"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "test.sub.example.com.", expectedMatch: "sub.example.com.", @@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, {"add", "example.com.", nbdns.PriorityDNSRoute, true}, }, query: "sub.example.com.", @@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "other.example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "other.example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 7b845235c..e81aebf98 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) muxUpdates = append(muxUpdates, handlerWrapper{ domain: customZone.Domain, handler: s.localResolver, - priority: PriorityMatchDomain, + priority: PriorityLocal, }) for _, record := range customZone.Records { @@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam groupedNS := groupNSGroupsByDomain(nameServerGroups) for _, domainGroup := range groupedNS { - basePriority := PriorityMatchDomain + basePriority := PriorityUpstream if domainGroup.domain == nbdns.RootZone { basePriority = PriorityDefault } @@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts priority := basePriority - i - // Check if we're about to overlap with the next priority tier - if basePriority == PriorityMatchDomain && priority <= PriorityDefault { + // Check if we're about to overlap with the next priority tier. + // This boundary check ensures that the priority of upstream handlers does not conflict + // with the default priority tier. By decrementing the priority for each handler, we avoid + // overlaps, but if the calculated priority falls into the default tier, we skip the remaining + // handlers to maintain the integrity of the priority system. + if basePriority == PriorityUpstream && priority <= PriorityDefault { log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityMatchDomain-PriorityDefault) + domainGroup.domain, PriorityUpstream-PriorityDefault) break } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index e55b27910..1cf59fb5b 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, dummyHandler.ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityLocal, }, generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: nbdns.RootZone, @@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "local-resolver": handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityLocal, }, }, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, @@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { "id1": handlerWrapper{ domain: zoneRecords[0].Name, handler: &local.Resolver{}, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, } //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} @@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { } chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) - chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) + chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream) testCases := []struct { name string @@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "upstream-group2": { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, } @@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "upstream-group2": { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, "upstream-other": { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, } @@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, }, expectedHandlers: map[string]string{ @@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group3", }, - priority: PriorityMatchDomain + 1, + priority: PriorityUpstream + 1, }, // Keep existing groups with their original priorities { @@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, }, expectedHandlers: map[string]string{ @@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, // Add group3 with lowest priority { @@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group3", }, - priority: PriorityMatchDomain - 2, + priority: PriorityUpstream - 2, }, }, expectedHandlers: map[string]string{ @@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "new.com", handler: &mockHandler{ Id: "upstream-new", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) { // Register domains from different handlers with same domain server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) - server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) + server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream) // Verify refcount is 2 zoneKey := toZone("shared.example.com") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") // Deregister one handler - server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) + server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream) // Verify refcount is 1 assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") @@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) { } server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) - server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) + server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream) assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") @@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) { assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") } + +func TestLocalResolverPriorityInServer(t *testing.T) { + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + handlerChain: NewHandlerChain(), + localResolver: local.NewResolver(), + service: &mockService{}, + extraDomains: make(map[domain.Domain]int), + } + + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "local.example.com", + Records: []nbdns.SimpleRecord{ + { + Name: "test.local.example.com", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.100", + }, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + Domains: []string{"local.example.com"}, // Same domain as local records + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + }, + }, + } + + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + assert.NoError(t, err) + + upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups) + assert.NoError(t, err) + + // Verify that local handler has higher priority than upstream for same domain + var localPriority, upstreamPriority int + localFound, upstreamFound := false, false + + for _, update := range localMuxUpdates { + if update.domain == "local.example.com" { + localPriority = update.priority + localFound = true + } + } + + for _, update := range upstreamMuxUpdates { + if update.domain == "local.example.com" { + upstreamPriority = update.priority + upstreamFound = true + } + } + + assert.True(t, localFound, "Local handler should be found") + assert.True(t, upstreamFound, "Upstream handler should be found") + assert.Greater(t, localPriority, upstreamPriority, + "Local handler priority (%d) should be higher than upstream priority (%d)", + localPriority, upstreamPriority) + assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal") + assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream") +} + +func TestLocalResolverPriorityConstants(t *testing.T) { + // Test that priority constants are ordered correctly + assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route") + assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream") + assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default") + + // Test that local resolver uses the correct priority + server := &DefaultServer{ + localResolver: local.NewResolver(), + } + + config := nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "local.example.com", + Records: []nbdns.SimpleRecord{ + { + Name: "test.local.example.com", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.100", + }, + }, + }, + }, + } + + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + assert.NoError(t, err) + assert.Len(t, localMuxUpdates, 1) + assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") + assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) +}