mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-23 19:21:23 +02:00
[client] Fix local resolver returning error for existing domains with other types (#3959)
This commit is contained in:
parent
0ad2590974
commit
3e43298471
@ -12,16 +12,19 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
domains map[domain.Domain]struct{}
|
||||
}
|
||||
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
domains: make(map[domain.Domain]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,8 +67,12 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
// TODO: return success if we have a different record type for the same name, relevant for search domains
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
// Check if we have any records for this domain name with different types
|
||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(replyMessage); err != nil {
|
||||
@ -73,6 +80,15 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
||||
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
_, exists := d.domains[domainName]
|
||||
return exists
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
@ -111,6 +127,7 @@ func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
maps.Clear(d.domains)
|
||||
|
||||
for _, rec := range update {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
@ -144,6 +161,7 @@ func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
}
|
||||
|
||||
d.records[q] = append(d.records[q], rr)
|
||||
d.domains[domain.Domain(q.Name)] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -470,3 +470,115 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_NoErrorWithDifferentRecordType verifies that querying for a record type
|
||||
// that doesn't exist but where other record types exist for the same domain returns NOERROR
|
||||
// with 0 records instead of NXDOMAIN
|
||||
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "example.netbird.cloud.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.1.100",
|
||||
}
|
||||
|
||||
recordCNAME := nbdns.SimpleRecord{
|
||||
Name: "alias.netbird.cloud.",
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "target.example.com.",
|
||||
}
|
||||
|
||||
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
queryName string
|
||||
queryType uint16
|
||||
expectedRcode int
|
||||
shouldHaveData bool
|
||||
}{
|
||||
{
|
||||
name: "Query A record that exists",
|
||||
queryName: "example.netbird.cloud.",
|
||||
queryType: dns.TypeA,
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
shouldHaveData: true,
|
||||
},
|
||||
{
|
||||
name: "Query AAAA for domain with only A record",
|
||||
queryName: "example.netbird.cloud.",
|
||||
queryType: dns.TypeAAAA,
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
{
|
||||
name: "Query other record with different case and non-fqdn",
|
||||
queryName: "EXAMPLE.netbird.cloud",
|
||||
queryType: dns.TypeAAAA,
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
{
|
||||
name: "Query TXT for domain with only A record",
|
||||
queryName: "example.netbird.cloud.",
|
||||
queryType: dns.TypeTXT,
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
{
|
||||
name: "Query A for domain with only CNAME record",
|
||||
queryName: "alias.netbird.cloud.",
|
||||
queryType: dns.TypeA,
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
shouldHaveData: true,
|
||||
},
|
||||
{
|
||||
name: "Query AAAA for domain with only CNAME record",
|
||||
queryName: "alias.netbird.cloud.",
|
||||
queryType: dns.TypeAAAA,
|
||||
expectedRcode: dns.RcodeSuccess,
|
||||
shouldHaveData: true,
|
||||
},
|
||||
{
|
||||
name: "Query for completely non-existent domain",
|
||||
queryName: "nonexistent.netbird.cloud.",
|
||||
queryType: dns.TypeA,
|
||||
expectedRcode: dns.RcodeNameError,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var responseMSG *dns.Msg
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
|
||||
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||
|
||||
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode,
|
||||
"Response code should be %d (%s)",
|
||||
tc.expectedRcode, dns.RcodeToString[tc.expectedRcode])
|
||||
|
||||
if tc.shouldHaveData {
|
||||
assert.Greater(t, len(responseMSG.Answer), 0, "Response should contain answers")
|
||||
} else {
|
||||
assert.Equal(t, 0, len(responseMSG.Answer), "Response should contain no answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user