From de7384e8ea363815f7a05dc3c96f7a4eaca73073 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 17 Jun 2025 14:03:00 +0200 Subject: [PATCH] [client] Tighten allowed domains for dns forwarder (#3978) --- client/internal/dns/upstream.go | 31 +- client/internal/dnsfwd/forwarder.go | 71 +- client/internal/dnsfwd/forwarder_test.go | 630 +++++++++++++++++- .../routemanager/dnsinterceptor/handler.go | 31 +- 4 files changed, 697 insertions(+), 66 deletions(-) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 2fbfb3b91..c44d36599 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -2,6 +2,7 @@ package dns import ( "context" + "crypto/rand" "crypto/sha256" "encoding/hex" "errors" @@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() { // ServeDNS handles a DNS request func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + requestID := GenerateRequestID() + logger := log.WithField("request_id", requestID) var err error defer func() { u.checkUpstreamFails(err) }() - log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } select { case <-u.ctx.Done(): - log.Tracef("%s has been stopped", u) + logger.Tracef("%s has been stopped", u) return default: } @@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if err != nil { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) + logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) continue } - log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) continue } if rm == nil || !rm.Response { - log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) continue } u.successCount.Add(1) - log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) if err = w.WriteMsg(rm); err != nil { - log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) + logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) } // count the fails only if they happen sequentially u.failsCount.Store(0) return } u.failsCount.Add(1) - log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) + logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) m.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(m); err != nil { - log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) + logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) } } @@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } + +func GenerateRequestID() string { + bytes := make([]byte, 4) + _, err := rand.Read(bytes) + if err != nil { + log.Errorf("failed to generate request ID: %v", err) + return "" + } + return hex.EncodeToString(bytes) +} diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 45b479632..506c429cd 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -18,14 +18,20 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) const errResolveFailed = "failed to resolve query for domain=%s: %v" const upstreamTimeout = 15 * time.Second +type resolver interface { + LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) +} + +type firewaller interface { + UpdateSet(set firewall.Set, prefixes []netip.Prefix) error +} + type DNSForwarder struct { listenAddress string ttl uint32 @@ -38,16 +44,18 @@ type DNSForwarder struct { mutex sync.RWMutex fwdEntries []*ForwarderEntry - firewall firewall.Manager + firewall firewaller + resolver resolver } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, firewall: firewall, statusRecorder: statusRecorder, + resolver: net.DefaultResolver, } } @@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { // UDP server mux := dns.NewServeMux() f.mux = mux + mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ Addr: f.listenAddress, Net: "udp", Handler: mux, } + // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux + tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ Addr: f.listenAddress, Net: "tcp", @@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { // return the first error we get (e.g. bind failure or shutdown) return <-errCh } + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() - if f.mux == nil { - log.Debug("DNS mux is nil, skipping domain update") - f.fwdEntries = entries - return - } - - oldDomains := filterDomains(f.fwdEntries) - for _, d := range oldDomains { - f.mux.HandleRemove(d.PunycodeString()) - f.tcpMux.HandleRemove(d.PunycodeString()) - } - - newDomains := filterDomains(entries) - for _, d := range newDomains { - f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP) - f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP) - } - f.fwdEntries = entries - log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) + log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } func (f *DNSForwarder) Close(ctx context.Context) error { @@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns return nil } + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) + // query doesn't match any configured domain + if mostSpecificResId == "" { + resp.Rcode = dns.RcodeRefused + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) + ips, err := f.resolver.LookupNetIP(ctx, network, domain) if err != nil { f.handleDNSError(w, query, resp, domain, err) return nil } - f.updateInternalState(domain, ips) + f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) return resp } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { - resp := f.handleDNSQuery(w, query) if resp == nil { return @@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { } } -func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { +func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { var prefixes []netip.Prefix - mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) if mostSpecificResId != "" { for _, ip := range ips { var prefix netip.Prefix @@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar return selectedResId, matches } - -// filterDomains returns a list of normalized domains -func filterDomains(entries []*ForwarderEntry) domain.List { - newDomains := make(domain.List, 0, len(entries)) - for _, d := range entries { - if d.Domain == "" { - log.Warn("empty domain in DNS forwarder") - continue - } - newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString()))) - } - return newDomains -} diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index f0829bbbd..d8228c733 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -1,11 +1,21 @@ package dnsfwd import ( + "context" + "fmt" + "net/netip" + "strings" "testing" + "time" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -13,7 +23,7 @@ import ( func Test_getMatchingEntries(t *testing.T) { testCases := []struct { name string - storedMappings map[string]route.ResID // key: domain pattern, value: resId + storedMappings map[string]route.ResID queryDomain string expectedResId route.ResID }{ @@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) { { name: "Wildcard pattern does not match different domain", storedMappings: map[string]route.ResID{"*.example.com": "res4"}, - queryDomain: "foo.notexample.com", + queryDomain: "foo.example.org", expectedResId: "", }, { @@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) { }) } } + +type MockFirewall struct { + mock.Mock +} + +func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + args := m.Called(set, prefixes) + return args.Error(0) +} + +type MockResolver struct { + mock.Mock +} + +func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { + args := m.Called(ctx, network, host) + return args.Get(0).([]netip.Addr), args.Error(1) +} + +func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) { + tests := []struct { + name string + configuredDomain string + queryDomain string + shouldMatch bool + expectedResID route.ResID + description string + }{ + { + name: "exact domain match should be allowed", + configuredDomain: "example.com", + queryDomain: "example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Direct match to configured domain should work", + }, + { + name: "subdomain access should be restricted", + configuredDomain: "example.com", + queryDomain: "mail.example.com", + shouldMatch: false, + expectedResID: "", + description: "Subdomain should not be accessible unless explicitly configured", + }, + { + name: "wildcard should allow subdomains", + configuredDomain: "*.example.com", + queryDomain: "mail.example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard domains should allow subdomain access", + }, + { + name: "wildcard should allow base domain", + configuredDomain: "*.example.com", + queryDomain: "example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard should also match the base domain", + }, + { + name: "deep subdomain should be restricted", + configuredDomain: "example.com", + queryDomain: "deep.mail.example.com", + shouldMatch: false, + expectedResID: "", + description: "Deep subdomains should not be accessible", + }, + { + name: "wildcard allows deep subdomains", + configuredDomain: "*.example.com", + queryDomain: "deep.mail.example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard should allow deep subdomains", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + forwarder := &DNSForwarder{} + + d, err := domain.FromString(tt.configuredDomain) + require.NoError(t, err) + + entries := []*ForwarderEntry{ + { + Domain: d, + ResID: "test-res-id", + }, + } + + forwarder.UpdateDomains(entries) + + resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain) + + if tt.shouldMatch { + assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID") + assert.NotEmpty(t, matchingEntries, "Expected matching entries") + t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain) + } else { + assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match") + assert.Empty(t, matchingEntries, "Expected no matching entries") + t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain) + } + }) + } +} + +func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + tests := []struct { + name string + configuredDomain string + queryDomain string + shouldResolve bool + description string + }{ + { + name: "configured exact domain resolves", + configuredDomain: "example.com", + queryDomain: "example.com", + shouldResolve: true, + description: "Exact match should resolve", + }, + { + name: "unauthorized subdomain blocked", + configuredDomain: "example.com", + queryDomain: "mail.example.com", + shouldResolve: false, + description: "Subdomain should be blocked without wildcard", + }, + { + name: "wildcard allows subdomain", + configuredDomain: "*.example.com", + queryDomain: "mail.example.com", + shouldResolve: true, + description: "Wildcard should allow subdomain", + }, + { + name: "wildcard allows base domain", + configuredDomain: "*.example.com", + queryDomain: "example.com", + shouldResolve: true, + description: "Wildcard should allow base domain", + }, + { + name: "unrelated domain blocked", + configuredDomain: "example.com", + queryDomain: "example.org", + shouldResolve: false, + description: "Unrelated domain should be blocked", + }, + { + name: "deep subdomain blocked", + configuredDomain: "example.com", + queryDomain: "deep.mail.example.com", + shouldResolve: false, + description: "Deep subdomain should be blocked", + }, + { + name: "wildcard allows deep subdomain", + configuredDomain: "*.example.com", + queryDomain: "deep.mail.example.com", + shouldResolve: true, + description: "Wildcard should allow deep subdomain", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + if tt.shouldResolve { + mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil) + + // Mock successful DNS resolution + fakeIP := netip.MustParseAddr("1.2.3.4") + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) + } + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString(tt.configuredDomain) + require.NoError(t, err) + + entries := []*ForwarderEntry{ + { + Domain: d, + ResID: "test-res-id", + Set: firewall.NewDomainSet([]domain.Domain{d}), + }, + } + + forwarder.UpdateDomains(entries) + + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + if tt.shouldResolve { + require.NotNil(t, resp, "Expected response for authorized domain") + require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") + assert.NotEmpty(t, resp.Answer, "Expected DNS answer records") + + time.Sleep(10 * time.Millisecond) + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) + } else { + if resp != nil { + assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, + "Unauthorized domain should not return successful answers") + } + mockFirewall.AssertNotCalled(t, "UpdateSet") + mockResolver.AssertNotCalled(t, "LookupNetIP") + } + }) + } +} + +func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { + tests := []struct { + name string + configuredDomains []string + query string + mockIP string + shouldResolve bool + expectedSetCount int // How many sets should be updated + description string + }{ + { + name: "exact domain gets firewall update", + configuredDomains: []string{"example.com"}, + query: "example.com", + mockIP: "1.1.1.1", + shouldResolve: true, + expectedSetCount: 1, + description: "Single exact match updates one set", + }, + { + name: "wildcard domain gets firewall update", + configuredDomains: []string{"*.example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.2", + shouldResolve: true, + expectedSetCount: 1, + description: "Wildcard match updates one set", + }, + { + name: "overlapping exact and wildcard both get updates", + configuredDomains: []string{"*.example.com", "mail.example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.3", + shouldResolve: true, + expectedSetCount: 2, + description: "Both exact and wildcard sets should be updated", + }, + { + name: "unauthorized domain gets no firewall update", + configuredDomains: []string{"example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.4", + shouldResolve: false, + expectedSetCount: 0, + description: "No firewall update for unauthorized domains", + }, + { + name: "multiple wildcards matching get all updated", + configuredDomains: []string{"*.example.com", "*.sub.example.com"}, + query: "test.sub.example.com", + mockIP: "1.1.1.5", + shouldResolve: true, + expectedSetCount: 2, + description: "All matching wildcard sets should be updated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + // Set up forwarder + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Create entries and track sets + var entries []*ForwarderEntry + sets := make([]firewall.Set, 0) + + for i, configDomain := range tt.configuredDomains { + d, err := domain.FromString(configDomain) + require.NoError(t, err) + + set := firewall.NewDomainSet([]domain.Domain{d}) + sets = append(sets, set) + + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: route.ResID(fmt.Sprintf("res-%d", i)), + Set: set, + }) + } + + forwarder.UpdateDomains(entries) + + // Set up mocks + if tt.shouldResolve { + fakeIP := netip.MustParseAddr(tt.mockIP) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)). + Return([]netip.Addr{fakeIP}, nil).Once() + + expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)} + + // Count how many sets should actually match + updateCount := 0 + for i, entry := range entries { + domain := strings.ToLower(tt.query) + pattern := entry.Domain.PunycodeString() + + matches := false + if strings.HasPrefix(pattern, "*.") { + baseDomain := strings.TrimPrefix(pattern, "*.") + if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) { + matches = true + } + } else if domain == pattern { + matches = true + } + + if matches { + mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once() + updateCount++ + } + } + + assert.Equal(t, tt.expectedSetCount, updateCount, + "Expected %d sets to be updated, but mock expects %d", + tt.expectedSetCount, updateCount) + } + + // Execute query + dnsQuery := &dns.Msg{} + dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, dnsQuery) + + // Verify response + if tt.shouldResolve { + require.NotNil(t, resp, "Expected response for authorized domain") + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.NotEmpty(t, resp.Answer) + } else if resp != nil { + assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, + "Unauthorized domain should be refused or have no answers") + } + + // Verify all mock expectations were met + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) + }) + } +} + +// Test to verify that multiple IPs for one domain result in all prefixes being sent together +func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Configure a single domain + d, err := domain.FromString("example.com") + require.NoError(t, err) + + set := firewall.NewDomainSet([]domain.Domain{d}) + entries := []*ForwarderEntry{{ + Domain: d, + ResID: "test-res", + Set: set, + }} + + forwarder.UpdateDomains(entries) + + // Mock resolver returns multiple IPs + ips := []netip.Addr{ + netip.MustParseAddr("1.1.1.1"), + netip.MustParseAddr("1.1.1.2"), + netip.MustParseAddr("1.1.1.3"), + } + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com."). + Return(ips, nil).Once() + + // Expect ONE UpdateSet call with ALL prefixes + expectedPrefixes := []netip.Prefix{ + netip.PrefixFrom(ips[0], 32), + netip.PrefixFrom(ips[1], 32), + netip.PrefixFrom(ips[2], 32), + } + mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once() + + // Execute query + query := &dns.Msg{} + query.SetQuestion("example.com.", dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + // Verify response contains all IPs + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 3, "Should have 3 answer records") + + // Verify mocks + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) +} + +func TestDNSForwarder_ResponseCodes(t *testing.T) { + tests := []struct { + name string + queryType uint16 + queryDomain string + configured string + expectedCode int + description string + }{ + { + name: "unauthorized domain returns REFUSED", + queryType: dns.TypeA, + queryDomain: "evil.com", + configured: "example.com", + expectedCode: dns.RcodeRefused, + description: "RFC compliant REFUSED for unauthorized queries", + }, + { + name: "unsupported query type returns NOTIMP", + queryType: dns.TypeMX, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "RFC compliant NOTIMP for unsupported types", + }, + { + name: "CNAME query returns NOTIMP", + queryType: dns.TypeCNAME, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "CNAME queries not supported", + }, + { + name: "TXT query returns NOTIMP", + queryType: dns.TypeTXT, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "TXT queries not supported", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + + d, err := domain.FromString(tt.configured) + require.NoError(t, err) + + entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}} + forwarder.UpdateDomains(entries) + + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType) + + // Capture the written response + var writtenResp *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writtenResp = m + return nil + }, + } + + _ = forwarder.handleDNSQuery(mockWriter, query) + + // Check the response written to the writer + require.NotNil(t, writtenResp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + }) + } +} + +func TestDNSForwarder_TCPTruncation(t *testing.T) { + // Test that large UDP responses are truncated with TC bit set + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, _ := domain.FromString("example.com") + entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}} + forwarder.UpdateDomains(entries) + + // Mock many IPs to create a large response + var manyIPs []netip.Addr + for i := 0; i < 100; i++ { + manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256))) + } + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil) + + // Query without EDNS0 + query := &dns.Msg{} + query.SetQuestion("example.com.", dns.TypeA) + + var writtenResp *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writtenResp = m + return nil + }, + } + forwarder.handleDNSQueryUDP(mockWriter, query) + + require.NotNil(t, writtenResp) + assert.True(t, writtenResp.Truncated, "Large response should be truncated") + assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") +} + +func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { + // Test complex overlapping pattern scenarios + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Set up complex overlapping patterns + patterns := []string{ + "*.example.com", // Matches all subdomains + "*.mail.example.com", // More specific wildcard + "smtp.mail.example.com", // Exact match + "example.com", // Base domain + } + + var entries []*ForwarderEntry + sets := make(map[string]firewall.Set) + + for _, pattern := range patterns { + d, _ := domain.FromString(pattern) + set := firewall.NewDomainSet([]domain.Domain{d}) + sets[pattern] = set + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: route.ResID("res-" + pattern), + Set: set, + }) + } + + forwarder.UpdateDomains(entries) + + // Test smtp.mail.example.com - should match 3 patterns + fakeIP := netip.MustParseAddr("1.2.3.4") + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil) + + expectedPrefix := netip.PrefixFrom(fakeIP, 32) + // All three matching patterns should get firewall updates + mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + + query := &dns.Msg{} + query.SetQuestion("smtp.mail.example.com.", dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + + // Verify all three sets were updated + mockFirewall.AssertExpectations(t) + + // Verify the most specific ResID was selected + // (exact match should win over wildcards) + resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com") + assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID) + assert.Len(t, matches, 3, "Should match 3 patterns") +} + +func TestDNSForwarder_EmptyQuery(t *testing.T) { + // Test handling of malformed query with no questions + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + + query := &dns.Msg{} + // Don't set any question + + writeCalled := false + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writeCalled = true + return nil + }, + } + resp := forwarder.handleDNSQuery(mockWriter, query) + + assert.Nil(t, resp, "Should return nil for empty query") + assert.False(t, writeCalled, "Should not write response for empty query") +} diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 23478c88c..d643d1e32 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -144,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error { // ServeDNS implements the dns.Handler interface func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + requestID := nbdns.GenerateRequestID() + logger := log.WithField("request_id", requestID) + if len(r.Question) == 0 { return } - log.Tracef("received DNS request for domain=%s type=%v class=%v", + logger.Tracef("received DNS request for domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) // pass if non A/AAAA query if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { - d.continueToNextHandler(w, r, "non A/AAAA query") + d.continueToNextHandler(w, r, logger, "non A/AAAA query") return } @@ -161,13 +164,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { d.mu.RUnlock() if peerKey == "" { - d.writeDNSError(w, r, "no current peer key") + d.writeDNSError(w, r, logger, "no current peer key") return } upstreamIP, err := d.getUpstreamIP(peerKey) if err != nil { - d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err)) + d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err)) return } @@ -184,9 +187,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) if err != nil { - log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { - log.Errorf("failed writing DNS response: %v", err) + logger.Errorf("failed writing DNS response: %v", err) } return } @@ -196,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { answer = reply.Answer } - log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) + logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) reply.Id = r.Id if err := d.writeMsg(w, reply); err != nil { - log.Errorf("failed writing DNS response: %v", err) + logger.Errorf("failed writing DNS response: %v", err) } } -func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) { - log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) +func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) { + logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) resp := new(dns.Msg) resp.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS error response: %v", err) + logger.Errorf("failed to write DNS error response: %v", err) } } // continueToNextHandler signals the handler chain to try the next handler -func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { - log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) { + logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) resp := new(dns.Msg) resp.SetRcode(r, dns.RcodeNameError) // Set Zero bit to signal handler chain to continue resp.MsgHdr.Zero = true if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed writing DNS continue response: %v", err) + logger.Errorf("failed writing DNS continue response: %v", err) } }