diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 506c429cd..d912919a1 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns defer cancel() ips, err := f.resolver.LookupNetIP(ctx, network, domain) if err != nil { - f.handleDNSError(w, query, resp, domain, err) + f.handleDNSError(ctx, w, question, resp, domain, err) return nil } @@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe } } +// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true +// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type) +// +// LIMITATION: This function only checks A and AAAA record types to determine domain existence. +// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records, +// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder +// only handles A/AAAA queries and returns NOTIMP for other types. +func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) { + // Try querying for a different record type to see if the domain exists + // If the original query was for AAAA, try A. If it was for A, try AAAA. + // This helps distinguish between NXDOMAIN and NODATA. + var alternativeNetwork string + switch originalQtype { + case dns.TypeAAAA: + alternativeNetwork = "ip4" + case dns.TypeA: + alternativeNetwork = "ip6" + default: + resp.Rcode = dns.RcodeNameError + return + } + + if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil { + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) && dnsErr.IsNotFound { + // Alternative query also returned not found - domain truly doesn't exist + resp.Rcode = dns.RcodeNameError + return + } + // Some other error (timeout, server failure, etc.) - can't determine, assume domain exists + resp.Rcode = dns.RcodeSuccess + return + } + + // Alternative query succeeded - domain exists but has no records of this type + resp.Rcode = dns.RcodeSuccess +} + // handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) { +func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { var dnsErr *net.DNSError switch { case errors.As(err, &dnsErr): resp.Rcode = dns.RcodeServerFailure if dnsErr.IsNotFound { - // Pass through NXDOMAIN - resp.Rcode = dns.RcodeNameError + f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype) } if dnsErr.Server != "" { - log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err) + log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) } else { log.Warnf(errResolveFailed, domain, err) } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index c820fbb60..57085e19a 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -3,6 +3,7 @@ package dnsfwd import ( "context" "fmt" + "net" "net/netip" "strings" "testing" @@ -16,8 +17,8 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) func Test_getMatchingEntries(t *testing.T) { @@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { assert.Len(t, matches, 3, "Should match 3 patterns") } +// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes +// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type) +func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("example.com") + require.NoError(t, err) + + set := firewall.NewDomainSet([]domain.Domain{d}) + entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}} + forwarder.UpdateDomains(entries) + + tests := []struct { + name string + queryType uint16 + setupMocks func() + expectedCode int + expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case) + description string + }{ + { + name: "domain exists but no AAAA records (NODATA)", + queryType: dns.TypeAAAA, + setupMocks: func() { + // First query for AAAA returns not found + mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com."). + Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once() + // Check query for A records succeeds (domain exists) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com."). + Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once() + }, + expectedCode: dns.RcodeSuccess, + expectNoAnswer: true, + description: "Should return NOERROR when domain exists but has no records of requested type", + }, + { + name: "domain exists but no A records (NODATA)", + queryType: dns.TypeA, + setupMocks: func() { + // First query for A returns not found + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com."). + Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once() + // Check query for AAAA records succeeds (domain exists) + mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com."). + Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once() + }, + expectedCode: dns.RcodeSuccess, + expectNoAnswer: true, + description: "Should return NOERROR when domain exists but has no A records", + }, + { + name: "domain doesn't exist (NXDOMAIN)", + queryType: dns.TypeA, + setupMocks: func() { + // First query for A returns not found + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com."). + Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once() + // Check query for AAAA also returns not found (domain doesn't exist) + mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com."). + Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once() + }, + expectedCode: dns.RcodeNameError, + expectNoAnswer: true, + description: "Should return NXDOMAIN when domain doesn't exist at all", + }, + { + name: "domain exists with records (normal success)", + queryType: dns.TypeA, + setupMocks: func() { + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com."). + Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once() + // Expect firewall update for successful resolution + expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32) + mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once() + }, + expectedCode: dns.RcodeSuccess, + expectNoAnswer: false, + description: "Should return NOERROR with answer when records exist", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset mock expectations + mockResolver.ExpectedCalls = nil + mockResolver.Calls = nil + mockFirewall.ExpectedCalls = nil + mockFirewall.Calls = nil + + tt.setupMocks() + + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn("example.com"), tt.queryType) + + var writtenResp *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writtenResp = m + return nil + }, + } + + resp := forwarder.handleDNSQuery(mockWriter, query) + + // If a response was returned, it means it should be written (happens in wrapper functions) + if resp != nil && writtenResp == nil { + writtenResp = resp + } + + require.NotNil(t, writtenResp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + + if tt.expectNoAnswer { + assert.Empty(t, writtenResp.Answer, "Response should have no answer records") + } + + mockResolver.AssertExpectations(t) + }) + } +} + func TestDNSForwarder_EmptyQuery(t *testing.T) { // Test handling of malformed query with no questions forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})