From 21ba6ad2667cb97841167363f241b570c7fe269d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 17 Dec 2024 12:57:07 +0100 Subject: [PATCH] Improve dns forwarder errors and improve domain anonymization (#3052) * Improve dns forwarder errors and improve domain anonymization * Use original domain for dns states * Don't match subdomains for non-wildcard dns routes * Fix iOS * Add string representation for local resolver * Return correct handler for dynamic * Add dns server dns route + upstream handler test --- client/anonymize/anonymize.go | 15 +- client/anonymize/anonymize_test.go | 56 ++++++- client/cmd/networks.go | 16 +- client/internal/dns/handler_chain.go | 80 ++++++--- client/internal/dns/handler_chain_test.go | 153 ++++++++++-------- client/internal/dns/local.go | 14 +- client/internal/dns/server_test.go | 85 ++++++++++ client/internal/dns/upstream.go | 6 +- client/internal/dnsfwd/forwarder.go | 37 ++++- client/internal/peer/status.go | 28 +++- client/internal/routemanager/client.go | 2 +- .../routemanager/dnsinterceptor/handler.go | 56 +++++-- client/internal/routemanager/dynamic/route.go | 2 +- client/ios/NetBirdSDK/client.go | 7 +- client/server/network.go | 27 +++- client/ui/network.go | 6 +- 16 files changed, 429 insertions(+), 161 deletions(-) diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 9a6d97207..89552724a 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -21,6 +21,8 @@ type Anonymizer struct { currentAnonIPv6 netip.Addr startAnonIPv4 netip.Addr startAnonIPv6 netip.Addr + + domainKeyRegex *regexp.Regexp } func DefaultAddresses() (netip.Addr, netip.Addr) { @@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { currentAnonIPv6: startIPv6, startAnonIPv4: startIPv4, startAnonIPv6: startIPv6, + + domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`), } } @@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string { return re.ReplaceAllStringFunc(text, a.AnonymizeURI) } -// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string. func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { - domainPattern := `dns\.Question{Name:"([^"]+)",` - domainRegex := regexp.MustCompile(domainPattern) - - return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string { - parts := strings.Split(match, `"`) + return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string { + parts := strings.SplitN(match, "=", 2) if len(parts) >= 2 { domain := parts[1] if strings.HasSuffix(domain, anonTLD) { return match } - randomDomain := generateRandomString(10) + anonTLD - return strings.Replace(match, domain, randomDomain, 1) + return "domain=" + a.AnonymizeDomain(domain) } return match }) diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index a3aae1ee9..ff2e48869 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) { func TestAnonymizeDNSLogLine(t *testing.T) { anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) - testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}` + tests := []struct { + name string + input string + original string + expect string + }{ + { + name: "Basic domain with trailing content", + input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123", + original: "example.com", + expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`, + }, + { + name: "Domain with trailing dot", + input: "domain=example.com. processing request with status=pending", + original: "example.com", + expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`, + }, + { + name: "Multiple domains in log", + input: "forward domain=first.com status=ok, redirect to domain=second.com port=443", + original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately + expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`, + }, + { + name: "Already anonymized domain", + input: "got request domain=anon-xyz123.domain from=client1 to=server2", + original: "", // nothing should be anonymized + expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`, + }, + { + name: "Subdomain with trailing dot", + input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp", + original: "example.com", + expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`, + }, + { + name: "Handler chain pattern log", + input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100", + original: "example.com", + expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`, + }, + } - result := anonymizer.AnonymizeDNSLogLine(testLog) - require.NotEqual(t, testLog, result) - assert.NotContains(t, result, "example.com") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeDNSLogLine(tc.input) + if tc.original != "" { + assert.NotContains(t, result, tc.original) + } + assert.Regexp(t, tc.expect, result) + }) + } } func TestAnonymizeDomain(t *testing.T) { diff --git a/client/cmd/networks.go b/client/cmd/networks.go index 6ebf13810..7b9724bc5 100644 --- a/client/cmd/networks.go +++ b/client/cmd/networks.go @@ -68,19 +68,19 @@ func networksList(cmd *cobra.Command, _ []string) error { return nil } - printRoutes(cmd, resp) + printNetworks(cmd, resp) return nil } -func printRoutes(cmd *cobra.Command, resp *proto.ListNetworksResponse) { +func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) { cmd.Println("Available Networks:") for _, route := range resp.Routes { - printRoute(cmd, route) + printNetwork(cmd, route) } } -func printRoute(cmd *cobra.Command, route *proto.Network) { +func printNetwork(cmd *cobra.Command, route *proto.Network) { selectedStatus := getSelectedStatus(route) domains := route.GetDomains() @@ -113,12 +113,10 @@ func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus) } -func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) { +func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) { cmd.Printf(" Resolved IPs:\n") - for _, domain := range domains { - if ipList, exists := resolvedIPs[domain]; exists { - cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", ")) - } + for resolvedDomain, ipList := range resolvedIPs { + cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", ")) } } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 4a525844b..c6ac3ebd6 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -14,13 +14,19 @@ const ( PriorityDefault = 0 ) +type SubdomainMatcher interface { + dns.Handler + MatchSubdomains() bool +} + type HandlerEntry struct { - Handler dns.Handler - Priority int - Pattern string - OrigPattern string - IsWildcard bool - StopHandler handlerWithStop + Handler dns.Handler + Priority int + Pattern string + OrigPattern string + IsWildcard bool + StopHandler handlerWithStop + MatchSubdomains bool } // HandlerChain represents a prioritized chain of DNS handlers @@ -32,6 +38,7 @@ type HandlerChain struct { // ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain type ResponseWriterChain struct { dns.ResponseWriter + origPattern string shouldContinue bool } @@ -50,6 +57,11 @@ func NewHandlerChain() *HandlerChain { } } +// GetOrigPattern returns the original pattern of the handler that wrote the response +func (w *ResponseWriterChain) GetOrigPattern() string { + return w.origPattern +} + // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { c.mu.Lock() @@ -74,16 +86,23 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority } } - log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d", - pattern, origPattern, isWildcard, priority) + // Check if handler implements SubdomainMatcher interface + matchSubdomains := false + if matcher, ok := handler.(SubdomainMatcher); ok { + matchSubdomains = matcher.MatchSubdomains() + } + + log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", + pattern, origPattern, isWildcard, matchSubdomains, priority) entry := HandlerEntry{ - Handler: handler, - Priority: priority, - Pattern: pattern, - OrigPattern: origPattern, - IsWildcard: isWildcard, - StopHandler: stopHandler, + Handler: handler, + Priority: priority, + Pattern: pattern, + OrigPattern: origPattern, + IsWildcard: isWildcard, + StopHandler: stopHandler, + MatchSubdomains: matchSubdomains, } // Insert handler in priority order @@ -139,14 +158,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } qname := r.Question[0].Name - log.Debugf("handling DNS request for %s", qname) + log.Tracef("handling DNS request for domain=%s", qname) c.mu.RLock() defer c.mu.RUnlock() - log.Debugf("current handlers (%d):", len(c.handlers)) + log.Tracef("current handlers (%d):", len(c.handlers)) for _, h := range c.handlers { - log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d", + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) } @@ -160,30 +179,41 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) default: - matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + // For non-wildcard patterns: + // If handler wants subdomain matching, allow suffix match + // Otherwise require exact match + if entry.MatchSubdomains { + matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + } else { + matched = qname == entry.Pattern + } } if !matched { - log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false", - entry.OrigPattern, qname, entry.IsWildcard) + log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", + qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) continue } - log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v", - entry.OrigPattern, qname, entry.IsWildcard) - chainWriter := &ResponseWriterChain{ResponseWriter: w} + log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", + qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) + + chainWriter := &ResponseWriterChain{ + ResponseWriter: w, + origPattern: entry.OrigPattern, + } entry.Handler.ServeDNS(chainWriter, r) // If handler wants to continue, try next handler if chainWriter.shouldContinue { - log.Debugf("handler requested continue to next handler") + log.Tracef("handler requested continue to next handler") continue } return } // No handler matched or all handlers passed - log.Debugf("no handler found for %s", qname) + log.Tracef("no handler found for domain=%s", qname) resp := &dns.Msg{} resp.SetRcode(r, dns.RcodeNameError) if err := w.WriteMsg(resp); err != nil { diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 01ed5f4e7..727b6e908 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -11,23 +11,14 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" ) -// MockHandler implements dns.Handler interface for testing -type MockHandler struct { - mock.Mock -} - -func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - m.Called(w, r) -} - // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { chain := nbdns.NewHandlerChain() // Create mock handlers for different priorities - defaultHandler := &MockHandler{} - matchDomainHandler := &MockHandler{} - dnsRouteHandler := &MockHandler{} + defaultHandler := &nbdns.MockHandler{} + matchDomainHandler := &nbdns.MockHandler{} + dnsRouteHandler := &nbdns.MockHandler{} // Setup handlers with different priorities chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) @@ -58,78 +49,108 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { // TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { tests := []struct { - name string - handlerDomain string - queryDomain string - isWildcard bool - shouldMatch bool + name string + handlerDomain string + queryDomain string + isWildcard bool + matchSubdomains bool + shouldMatch bool }{ { - name: "exact match", - handlerDomain: "example.com.", - queryDomain: "example.com.", - isWildcard: false, - shouldMatch: true, + name: "exact match", + handlerDomain: "example.com.", + queryDomain: "example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, }, { - name: "subdomain with non-wildcard", - handlerDomain: "example.com.", - queryDomain: "sub.example.com.", - isWildcard: false, - shouldMatch: true, + name: "subdomain with non-wildcard and MatchSubdomains true", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, }, { - name: "wildcard match", - handlerDomain: "*.example.com.", - queryDomain: "sub.example.com.", - isWildcard: true, - shouldMatch: true, + name: "subdomain with non-wildcard and MatchSubdomains false", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, }, { - name: "wildcard no match on apex", - handlerDomain: "*.example.com.", - queryDomain: "example.com.", - isWildcard: true, - shouldMatch: false, + name: "wildcard match", + handlerDomain: "*.example.com.", + queryDomain: "sub.example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, }, { - name: "root zone match", - handlerDomain: ".", - queryDomain: "anything.com.", - isWildcard: false, - shouldMatch: true, + name: "wildcard no match on apex", + handlerDomain: "*.example.com.", + queryDomain: "example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: false, }, { - name: "no match different domain", - handlerDomain: "example.com.", - queryDomain: "example.org.", - isWildcard: false, - shouldMatch: false, + name: "root zone match", + handlerDomain: ".", + queryDomain: "anything.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "no match different domain", + handlerDomain: "example.com.", + queryDomain: "example.org.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { chain := nbdns.NewHandlerChain() - mockHandler := &MockHandler{} + var handler dns.Handler + + if tt.matchSubdomains { + mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + handler = mockSubHandler + if tt.shouldMatch { + mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } else { + mockHandler := &nbdns.MockHandler{} + handler = mockHandler + if tt.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } pattern := tt.handlerDomain if tt.isWildcard { - pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard + pattern = "*." + tt.handlerDomain[2:] } - chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil) + chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} - if tt.shouldMatch { - mockHandler.On("ServeDNS", mock.Anything, r).Once() - } - chain.ServeDNS(w, r) - mockHandler.AssertExpectations(t) + + if h, ok := handler.(*nbdns.MockHandler); ok { + h.AssertExpectations(t) + } else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok { + h.AssertExpectations(t) + } }) } } @@ -218,11 +239,11 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { chain := nbdns.NewHandlerChain() - var handlers []*MockHandler + var handlers []*nbdns.MockHandler // Setup handlers and expectations for i := range tt.handlers { - handler := &MockHandler{} + handler := &nbdns.MockHandler{} handlers = append(handlers, handler) // Set expectation based on whether this handler should be called @@ -254,9 +275,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { chain := nbdns.NewHandlerChain() // Create handlers - handler1 := &MockHandler{} - handler2 := &MockHandler{} - handler3 := &MockHandler{} + handler1 := &nbdns.MockHandler{} + handler2 := &nbdns.MockHandler{} + handler3 := &nbdns.MockHandler{} // Add handlers in priority order chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) @@ -388,12 +409,12 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { chain := nbdns.NewHandlerChain() - handlers := make(map[int]*MockHandler) + handlers := make(map[int]*nbdns.MockHandler) // Execute operations for _, op := range tt.ops { if op.action == "add" { - handler := &MockHandler{} + handler := &nbdns.MockHandler{} handlers[op.priority] = handler chain.AddHandler(op.pattern, handler, op.priority, nil) } else { @@ -440,10 +461,10 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { testDomain := "example.com." testQuery := "test.example.com." - // Create handlers for three priority levels - routeHandler := &MockHandler{} - matchHandler := &MockHandler{} - defaultHandler := &MockHandler{} + // Create handlers with MatchSubdomains enabled + routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true} // Create test request that will be reused r := new(dns.Msg) diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 6a459794b..9a78d4d50 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -17,12 +17,24 @@ type localResolver struct { records sync.Map } +func (d *localResolver) MatchSubdomains() bool { + return true +} + func (d *localResolver) stop() { } +// String returns a string representation of the local resolver +func (d *localResolver) String() string { + return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) +} + // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received question: %#v", r.Question[0]) + if len(r.Question) > 0 { + log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + } + replyMessage := &dns.Msg{} replyMessage.SetReply(r) replyMessage.RecursionAvailable = true diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index aca7653a3..44d20c6f3 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -11,7 +11,9 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" @@ -874,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver { }, } } + +// MockHandler implements dns.Handler interface for testing +type MockHandler struct { + mock.Mock +} + +func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + m.Called(w, r) +} + +type MockSubdomainHandler struct { + MockHandler + Subdomains bool +} + +func (m *MockSubdomainHandler) MatchSubdomains() bool { + return m.Subdomains +} + +func TestHandlerChain_DomainPriorities(t *testing.T) { + chain := NewHandlerChain() + + dnsRouteHandler := &MockHandler{} + upstreamHandler := &MockSubdomainHandler{ + Subdomains: true, + } + + chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) + chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) + + testCases := []struct { + name string + query string + expectedHandler dns.Handler + }{ + { + name: "exact domain with dns route handler", + query: "example.com.", + expectedHandler: dnsRouteHandler, + }, + { + name: "subdomain should use upstream handler", + query: "sub.example.com.", + expectedHandler: upstreamHandler, + }, + { + name: "deep subdomain should use upstream handler", + query: "deep.sub.example.com.", + expectedHandler: upstreamHandler, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := new(dns.Msg) + r.SetQuestion(tc.query, dns.TypeA) + w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } + + chain.ServeDNS(w, r) + + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.AssertExpectations(t) + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.AssertExpectations(t) + } + + // Reset mocks + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } + }) + } +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 94497c61f..f0aa12b65 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -68,7 +68,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) * // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("%v", u.upstreamServers) + return fmt.Sprintf("upstream %v", u.upstreamServers) +} + +func (u *upstreamResolverBase) MatchSubdomains() bool { + return true } func (u *upstreamResolverBase) stop() { diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index dd9636158..f886a54d2 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -2,6 +2,7 @@ package dnsfwd import ( "context" + "errors" "net" "github.com/miekg/dns" @@ -10,6 +11,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) +const errResolveFailed = "failed to resolve query for domain=%s: %v" + type DNSForwarder struct { listenAddress string ttl uint32 @@ -20,15 +23,16 @@ type DNSForwarder struct { } func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder { - log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains) + log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d domains=%v", listenAddress, ttl, domains) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, domains: domains, } } + func (f *DNSForwarder) Listen() error { - log.Infof("listen DNS forwarder on: %s", f.listenAddress) + log.Infof("listen DNS forwarder on address=%s", f.listenAddress) mux := dns.NewServeMux() for _, d := range f.domains { @@ -67,7 +71,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { if len(query.Question) == 0 { return } - log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name) + log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", + query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass) question := query.Question[0] domain := question.Name @@ -76,8 +81,26 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { ips, err := net.LookupIP(domain) if err != nil { - log.Warnf("failed to resolve query for domain %s: %v", domain, err) - resp.Rcode = dns.RcodeServerFailure + var dnsErr *net.DNSError + + switch { + case errors.As(err, &dnsErr): + resp.Rcode = dns.RcodeServerFailure + if dnsErr.IsNotFound { + // Pass through NXDOMAIN + resp.Rcode = dns.RcodeNameError + } + + if dnsErr.Server != "" { + log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + } else { + log.Warnf(errResolveFailed, domain, err) + } + default: + resp.Rcode = dns.RcodeServerFailure + log.Warnf(errResolveFailed, domain, err) + } + if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write failure DNS response: %v", err) } @@ -87,7 +110,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { for _, ip := range ips { var respRecord dns.RR if ip.To4() == nil { - log.Tracef("resolved domain %s to IPv6 %s", domain, ip) + log.Tracef("resolved domain=%s to IPv6=%s", domain, ip) rr := dns.AAAA{ AAAA: ip, Hdr: dns.RR_Header{ @@ -99,7 +122,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { } respRecord = &rr } else { - log.Tracef("resolved domain %s to IPv4 %s", domain, ip) + log.Tracef("resolved domain=%s to IPv4=%s", domain, ip) rr := dns.A{ A: ip, Hdr: dns.RR_Header{ diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 74e2ee82c..dc461257a 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -17,6 +17,11 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" ) +type ResolvedDomainInfo struct { + Prefixes []netip.Prefix + ParentDomain domain.Domain +} + // State contains the latest state of a peer type State struct { Mux *sync.RWMutex @@ -138,7 +143,7 @@ type Status struct { rosenpassEnabled bool rosenpassPermissive bool nsGroupStates []NSGroupState - resolvedDomainsStates map[domain.Domain][]netip.Prefix + resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo // To reduce the number of notification invocation this bool will be true when need to call the notification // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events @@ -156,7 +161,7 @@ func NewRecorder(mgmAddress string) *Status { offlinePeers: make([]State, 0), notifier: newNotifier(), mgmAddress: mgmAddress, - resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix), + resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{}, } } @@ -591,16 +596,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } -func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) { +func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) { d.mux.Lock() defer d.mux.Unlock() - d.resolvedDomainsStates[domain] = prefixes + + // Store both the original domain pattern and resolved domain + d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{ + Prefixes: prefixes, + ParentDomain: originalDomain, + } } func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { d.mux.Lock() defer d.mux.Unlock() - delete(d.resolvedDomainsStates, domain) + + // Remove all entries that have this domain as their parent + for k, v := range d.resolvedDomainsStates { + if v.ParentDomain == domain { + delete(d.resolvedDomainsStates, k) + } + } } func (d *Status) GetRosenpassState() RosenpassState { @@ -702,7 +718,7 @@ func (d *Status) GetDNSStates() []NSGroupState { return d.nsGroupStates } -func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix { +func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { d.mux.Lock() defer d.mux.Unlock() return maps.Clone(d.resolvedDomainsStates) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 6265736a1..73f552aab 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -442,5 +442,5 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int { if useNewDNSRoute { return handlerTypeDomain } - return handlerTypeStatic + return handlerTypeDynamic } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 702290015..28bf20d5f 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -83,6 +83,8 @@ func (d *DnsInterceptor) RemoveRoute() error { } log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) + } + for _, domain := range d.route.Domains { d.statusRecorder.DeleteResolvedDomainsStates(domain) } @@ -138,14 +140,16 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { return } - log.Tracef("received DNS request: %v", r.Question[0].Name) + log.Tracef("received DNS request for domain=%s type=%v class=%v", + r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) d.mu.RLock() peerKey := d.currentPeerKey d.mu.RUnlock() if peerKey == "" { - log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name) + log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name) + d.continueToNextHandler(w, r, "no current peer key") return } @@ -168,7 +172,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if reply != nil { answer = reply.Answer } - log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer) + log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer) if err != nil { log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) @@ -186,7 +190,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // continueToNextHandler signals the handler chain to try the next handler func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { - log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason) + log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) + resp := new(dns.Msg) resp.SetRcode(r, dns.RcodeNameError) // Set Zero bit to signal handler chain to continue @@ -210,8 +215,18 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { } if len(r.Answer) > 0 && len(r.Question) > 0 { - // DNS names from miekg/dns are already in punycode format - dom := domain.Domain(r.Question[0].Name) + origPattern := "" + if writer, ok := w.(*nbdns.ResponseWriterChain); ok { + origPattern = writer.GetOrigPattern() + } + + resolvedDomain := domain.Domain(r.Question[0].Name) + + // already punycode via RegisterHandler() + originalDomain := domain.Domain(origPattern) + if originalDomain == "" { + originalDomain = resolvedDomain + } var newPrefixes []netip.Prefix for _, answer := range r.Answer { @@ -220,14 +235,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { case *dns.A: addr, ok := netip.AddrFromSlice(rr.A) if !ok { - log.Debugf("failed to convert A record IP: %v", rr.A) + log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) continue } ip = addr case *dns.AAAA: addr, ok := netip.AddrFromSlice(rr.AAAA) if !ok { - log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA) + log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) continue } ip = addr @@ -240,7 +255,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { } if len(newPrefixes) > 0 { - if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil { + if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { log.Errorf("failed to update domain prefixes: %v", err) } } @@ -253,11 +268,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { return nil } -func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error { +func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { d.mu.Lock() defer d.mu.Unlock() - oldPrefixes := d.interceptedDomains[domain] + oldPrefixes := d.interceptedDomains[resolvedDomain] toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) var merr *multierror.Error @@ -277,7 +292,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes } else if ref.Count > 1 && ref.Out != d.currentPeerKey { log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", prefix.Addr(), - domain.SafeString(), + resolvedDomain.SafeString(), ref.Out, ) } @@ -297,16 +312,23 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes } } - // Update domain prefixes + // Update domain prefixes using resolved domain as key if len(toAdd) > 0 || len(toRemove) > 0 { - d.interceptedDomains[domain] = newPrefixes - d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes) + d.interceptedDomains[resolvedDomain] = newPrefixes + originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) + d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes) if len(toAdd) > 0 { - log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd) + log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toAdd) } if len(toRemove) > 0 { - log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove) + log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toRemove) } } diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index b71a91f74..a0fff7713 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes) r.dynamicDomains[domain] = updatedPrefixes - r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes) + r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes) } return nberrors.FormatErrorOrNil(merr) diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 1b4413141..befce56a2 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } -func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails { +func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { var routeSelection []RoutesSelectionInfo for _, r := range routes { domainList := make([]DomainInfo, 0) @@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom domainResp := DomainInfo{ Domain: d.SafeString(), } - if prefixes, exists := resolvedDomains[d]; exists { + + if info, exists := resolvedDomains[d]; exists { var ipStrings []string - for _, prefix := range prefixes { + for _, prefix := range info.Prefixes { ipStrings = append(ipStrings, prefix.Addr().String()) } domainResp.ResolvedIPs = strings.Join(ipStrings, ", ") diff --git a/client/server/network.go b/client/server/network.go index b4b4071b4..aaf361524 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "slices" "sort" "golang.org/x/exp/maps" @@ -77,17 +78,27 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro Selected: route.Selected, } - for _, domain := range route.Domains { - if prefixes, exists := resolvedDomains[domain]; exists { - var ipStrings []string - for _, prefix := range prefixes { - ipStrings = append(ipStrings, prefix.Addr().String()) - } - pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ - Ips: ipStrings, + // Group resolved IPs by their parent domain + domainMap := map[domain.Domain][]string{} + + for resolvedDomain, info := range resolvedDomains { + // Check if this resolved domain's parent is in our route's domains + if slices.Contains(route.Domains, info.ParentDomain) { + ips := make([]string, 0, len(info.Prefixes)) + for _, prefix := range info.Prefixes { + ips = append(ips, prefix.Addr().String()) } + domainMap[resolvedDomain] = ips } } + + // Convert to proto format + for domain, ips := range domainMap { + pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ + Ips: ips, + } + } + pbRoutes = append(pbRoutes, pbRoute) } diff --git a/client/ui/network.go b/client/ui/network.go index a74c714e0..e6f027f0e 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -129,10 +129,8 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) { grid.Add(domainsSelector) var resolvedIPsList []string - for _, domain := range domains { - if ipList, exists := r.GetResolvedIPs()[domain]; exists { - resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) - } + for domain, ipList := range r.GetResolvedIPs() { + resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) } if len(resolvedIPsList) == 0 {