mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 18:41:41 +02:00
[client] Distinguish between NXDOMAIN and NODATA in the dns forwarder (#4321)
This commit is contained in:
@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.handleDNSError(w, query, resp, domain, err)
|
f.handleDNSError(ctx, w, question, resp, domain, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
|
||||||
|
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
|
||||||
|
//
|
||||||
|
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
|
||||||
|
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
|
||||||
|
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
|
||||||
|
// only handles A/AAAA queries and returns NOTIMP for other types.
|
||||||
|
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
|
||||||
|
// Try querying for a different record type to see if the domain exists
|
||||||
|
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||||
|
// This helps distinguish between NXDOMAIN and NODATA.
|
||||||
|
var alternativeNetwork string
|
||||||
|
switch originalQtype {
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
alternativeNetwork = "ip4"
|
||||||
|
case dns.TypeA:
|
||||||
|
alternativeNetwork = "ip6"
|
||||||
|
default:
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||||
|
// Alternative query also returned not found - domain truly doesn't exist
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||||
|
resp.Rcode = dns.RcodeSuccess
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alternative query succeeded - domain exists but has no records of this type
|
||||||
|
resp.Rcode = dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||||
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
|
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case errors.As(err, &dnsErr):
|
case errors.As(err, &dnsErr):
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
if dnsErr.IsNotFound {
|
if dnsErr.IsNotFound {
|
||||||
// Pass through NXDOMAIN
|
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||||
resp.Rcode = dns.RcodeNameError
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if dnsErr.Server != "" {
|
if dnsErr.Server != "" {
|
||||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
|
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||||
} else {
|
} else {
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package dnsfwd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -16,8 +17,8 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_getMatchingEntries(t *testing.T) {
|
func Test_getMatchingEntries(t *testing.T) {
|
||||||
@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
assert.Len(t, matches, 3, "Should match 3 patterns")
|
assert.Len(t, matches, 3, "Should match 3 patterns")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
|
||||||
|
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
|
||||||
|
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryType uint16
|
||||||
|
setupMocks func()
|
||||||
|
expectedCode int
|
||||||
|
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "domain exists but no AAAA records (NODATA)",
|
||||||
|
queryType: dns.TypeAAAA,
|
||||||
|
setupMocks: func() {
|
||||||
|
// First query for AAAA returns not found
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
// Check query for A records succeeds (domain exists)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeSuccess,
|
||||||
|
expectNoAnswer: true,
|
||||||
|
description: "Should return NOERROR when domain exists but has no records of requested type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "domain exists but no A records (NODATA)",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
setupMocks: func() {
|
||||||
|
// First query for A returns not found
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
// Check query for AAAA records succeeds (domain exists)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||||
|
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeSuccess,
|
||||||
|
expectNoAnswer: true,
|
||||||
|
description: "Should return NOERROR when domain exists but has no A records",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "domain doesn't exist (NXDOMAIN)",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
setupMocks: func() {
|
||||||
|
// First query for A returns not found
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
// Check query for AAAA also returns not found (domain doesn't exist)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeNameError,
|
||||||
|
expectNoAnswer: true,
|
||||||
|
description: "Should return NXDOMAIN when domain doesn't exist at all",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "domain exists with records (normal success)",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
setupMocks: func() {
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
|
||||||
|
// Expect firewall update for successful resolution
|
||||||
|
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
|
||||||
|
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
|
||||||
|
},
|
||||||
|
expectedCode: dns.RcodeSuccess,
|
||||||
|
expectNoAnswer: false,
|
||||||
|
description: "Should return NOERROR with answer when records exist",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset mock expectations
|
||||||
|
mockResolver.ExpectedCalls = nil
|
||||||
|
mockResolver.Calls = nil
|
||||||
|
mockFirewall.ExpectedCalls = nil
|
||||||
|
mockFirewall.Calls = nil
|
||||||
|
|
||||||
|
tt.setupMocks()
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||||
|
if resp != nil && writtenResp == nil {
|
||||||
|
writtenResp = resp
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||||
|
|
||||||
|
if tt.expectNoAnswer {
|
||||||
|
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
||||||
|
}
|
||||||
|
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||||
// Test handling of malformed query with no questions
|
// Test handling of malformed query with no questions
|
||||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
Reference in New Issue
Block a user