[client] Fix stale local records (#3776)

This commit is contained in:
Viktor Liu
2025-05-05 14:29:05 +02:00
committed by GitHub
parent ffdd115ded
commit 9762b39f29
13 changed files with 786 additions and 416 deletions

View File

@ -1,7 +1,6 @@
package dns_test
import (
"net"
"testing"
"github.com/miekg/dns"
@ -9,6 +8,7 @@ import (
"github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
r.SetQuestion("example.com.", dns.TypeA)
// Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify expectations
@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
}).Once()
// Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify all handlers were called in order
@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3.AssertExpectations(t)
}
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct {
name string
@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
// Create test request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations
for priority, handler := range handlers {
@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks {
@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r)
chain.ServeDNS(&test.MockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations
for pattern, handler := range handlers {
@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any
chain.ServeDNS(w, r)

View File

@ -1,130 +0,0 @@
package dns
import (
"fmt"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
type registrationMap map[string]struct{}
type localResolver struct {
registeredMap registrationMap
records sync.Map // key: string (domain_class_type), value: []dns.RR
}
func (d *localResolver) MatchSubdomains() bool {
return true
}
func (d *localResolver) stop() {
}
// String returns a string representation of the local resolver
func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}
// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 {
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(r)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
if len(r.Question) == 0 {
return nil
}
question := r.Question[0]
question.Name = strings.ToLower(question.Name)
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
value, found := d.records.Load(key)
if !found {
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
r.Question[0].Qtype = dns.TypeCNAME
return d.lookupRecords(r)
}
return nil
}
records, ok := value.([]dns.RR)
if !ok {
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
return nil
}
// if there's more than one record, rotate them (round-robin)
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records.Store(key, records)
}
return records
}
// registerRecord stores a new record by appending it to any existing list
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
rr, err := dns.NewRR(record.String())
if err != nil {
return "", fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
// load any existing slice of records, then append
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
records := existing.([]dns.RR)
records = append(records, rr)
// store updated slice
d.records.Store(key, records)
return key, nil
}
// deleteRecord removes *all* records under the recordKey.
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}
// buildRecordKey consistently generates a key: name_class_type
func buildRecordKey(name string, class, qType uint16) string {
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
}
func (d *localResolver) probeAvailability() {}

View File

@ -0,0 +1,149 @@
package local
import (
"fmt"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
)
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
}
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
// String returns a string representation of the local resolver
func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// TODO: return success if we have a different record type for the same name, relevant for search domains
replyMessage.Rcode = dns.RcodeNameError
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
records, found := d.records[question]
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
}
return nil
}
recordsCopy := slices.Clone(records)
d.mu.RUnlock()
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
}
d.mu.Unlock()
}
return recordsCopy
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
}
}
}
// RegisterRecord stores a new record by appending it to any existing list
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
d.mu.Lock()
defer d.mu.Unlock()
return d.registerRecord(record)
}
// registerRecord performs the registration with the lock already held
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
rr, err := dns.NewRR(record.String())
if err != nil {
return fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
q := dns.Question{
Name: strings.ToLower(dns.Fqdn(header.Name)),
Qtype: header.Rrtype,
Qclass: header.Class,
}
d.records[q] = append(d.records[q], rr)
return nil
}

View File

@ -0,0 +1,472 @@
package local
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := NewResolver()
_ = resolver.RegisterRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}
// TestLocalResolver_Update_StaleRecord verifies that updating
// a record correctly replaces the old one, preventing stale entries.
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
recordName := "host.example.com."
recordType := dns.TypeA
recordClass := dns.ClassINET
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
}
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
resolver := NewResolver()
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
// Apply first update
resolver.Update(update1)
// Verify first update
resolver.mu.RLock()
rrSlice1, found1 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found1, "Record key %s not found after first update", recordKey)
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(update2)
// Verify second update
resolver.mu.RLock()
rrSlice2, found2 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found2, "Record key %s not found after second update", recordKey)
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
}
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
// with the same question are stored properly
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
resolver := NewResolver()
recordName := "multi.example.com."
recordType := dns.TypeA
// Create two records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records
resolver.Update(update)
// Create question that matches both records
question := dns.Question{
Name: recordName,
Qtype: recordType,
Qclass: dns.ClassINET,
}
// Verify both records are stored
resolver.mu.RLock()
records, found := resolver.records[question]
resolver.mu.RUnlock()
require.True(t, found, "Records for question %v not found", question)
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
// Verify both record data values are present
recordStrings := []string{records[0].String(), records[1].String()}
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
}
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
func TestLocalResolver_RecordRotation(t *testing.T) {
resolver := NewResolver()
recordName := "rotation.example.com."
recordType := dns.TypeA
// Create three records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
}
record3 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records
resolver.Update(update)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
// First lookup - should return the records in original order
var responses [3]*dns.Msg
// Perform three lookups to verify rotation
for i := 0; i < 3; i++ {
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responses[i] = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
}
// Verify all three responses contain answers
for i, resp := range responses {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
}
// Verify the first record in each response is different due to rotation
firstRecordIPs := []string{
responses[0].Answer[0].String(),
responses[1].Answer[0].String(),
responses[2].Answer[0].String(),
}
// Each record should be different (rotated)
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
// After three rotations, we should have cycled through all records
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
}
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
resolver := NewResolver()
// Create record with lowercase name
lowerCaseRecord := nbdns.SimpleRecord{
Name: "lower.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.10.10.10",
}
// Create record with mixed case name
mixedCaseRecord := nbdns.SimpleRecord{
Name: "MiXeD.ExAmPlE.CoM.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "20.20.20.20",
}
// Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
testCases := []struct {
name string
queryName string
expectedRData string
shouldResolve bool
}{
{
name: "Query lowercase with lowercase record",
queryName: "lower.example.com.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query uppercase with lowercase record",
queryName: "LOWER.EXAMPLE.COM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query mixed case with lowercase record",
queryName: "LoWeR.eXaMpLe.CoM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query lowercase with mixed case record",
queryName: "mixed.example.com.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query uppercase with mixed case record",
queryName: "MIXED.EXAMPLE.COM.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query with different casing pattern",
queryName: "mIxEd.ExaMpLe.cOm.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query non-existent domain",
queryName: "nonexistent.example.com.",
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case name
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 {
// Expected no answer, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected IP address %s, got: %s",
tc.expectedRData, answerString)
})
}
}
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
// to checking for CNAME records when the requested record type isn't found
func TestLocalResolver_CNAMEFallback(t *testing.T) {
resolver := NewResolver()
// Create a CNAME record (but no A record for this name)
cnameRecord := nbdns.SimpleRecord{
Name: "alias.example.com.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
// Create an A record for the CNAME target
targetRecord := nbdns.SimpleRecord{
Name: "target.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.100.100",
}
// Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
testCases := []struct {
name string
queryName string
queryType uint16
expectedType string
expectedRData string
shouldResolve bool
}{
{
name: "Directly query CNAME record",
queryName: "alias.example.com.",
queryType: dns.TypeCNAME,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query A record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query AAAA record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeAAAA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query direct A record",
queryName: "target.example.com.",
queryType: dns.TypeA,
expectedType: "A",
expectedRData: "192.168.100.100",
shouldResolve: true,
},
{
name: "Query non-existent name",
queryName: "nonexistent.example.com.",
queryType: dns.TypeA,
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case parameters
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
// Expected no resolution, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a successful response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedType,
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
})
}
}

View File

@ -1,88 +0,0 @@
package dns
import (
"strings"
"testing"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_, _ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@ -1,26 +0,0 @@
package dns
import (
"net"
"github.com/miekg/dns"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@ -15,6 +15,8 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
@ -46,8 +48,6 @@ type Server interface {
ProbeAvailability()
}
type handlerID string
type nsGroupsByDomain struct {
domain string
groups []*nbdns.NameServerGroup
@ -61,7 +61,7 @@ type DefaultServer struct {
mux sync.Mutex
service service
dnsMuxMap registeredHandlerMap
localResolver *localResolver
localResolver *local.Resolver
wgInterface WGIface
hostManager hostManager
updateSerial uint64
@ -84,9 +84,9 @@ type DefaultServer struct {
type handlerWithStop interface {
dns.Handler
stop()
probeAvailability()
id() handlerID
Stop()
ProbeAvailability()
ID() types.HandlerID
}
type handlerWrapper struct {
@ -95,7 +95,7 @@ type handlerWrapper struct {
priority int
}
type registeredHandlerMap map[handlerID]handlerWrapper
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// NewDefaultServer returns a new dns server
func NewDefaultServer(
@ -171,16 +171,14 @@ func newDefaultServer(
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
ctxCancel: stop,
disableSys: disableSys,
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
ctx: ctx,
ctxCancel: stop,
disableSys: disableSys,
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(),
wgInterface: wgInterface,
statusRecorder: statusRecorder,
stateManager: stateManager,
@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Add(1)
go func(mux handlerWithStop) {
defer wg.Done()
mux.probeAvailability()
mux.ProbeAvailability()
}(mux.handler)
}
wg.Wait()
@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.service.Stop()
}
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
// register local records
s.updateLocalResolver(localRecordsByDomain)
s.localResolver.Update(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
)
}
func (s *DefaultServer) buildLocalHandlerUpdate(
customZones []nbdns.CustomZone,
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
var muxUpdates []handlerWrapper
localRecords := make(map[string][]nbdns.SimpleRecord)
var localRecords []nbdns.SimpleRecord
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
priority: PriorityMatchDomain,
})
// group all records under this domain
for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
key := buildRecordKey(record.Name, class, uint16(record.Type))
localRecords[key] = append(localRecords[key], record)
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
}
@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
}
if len(handler.upstreamServers) == 0 {
handler.stop()
handler.Stop()
log.Errorf("received a nameserver group with an invalid nameserver list")
continue
}
@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority)
existing.handler.stop()
existing.handler.Stop()
}
muxUpdateMap := make(registeredHandlerMap)
@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
containsRootUpdate = true
}
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.id()] = update
muxUpdateMap[update.handler.ID()] = update
}
// If there's no root update and we had a root handler, restore it
@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap
}
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
// remove old records that are no longer present
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
s.localResolver.deleteRecord(key)
}
}
updatedMap := make(registrationMap)
for _, recs := range update {
for _, rec := range recs {
// convert the record to a dns.RR and register
key, err := s.localResolver.registerRecord(rec)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v",
rec.String(), err)
continue
}
updatedMap[key] = struct{}{}
}
}
s.localResolver.registeredMap = updatedMap
}
func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}

View File

@ -23,6 +23,9 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) {
},
}
dummyHandler := &localResolver{}
dummyHandler := local.NewResolver()
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalMap registrationMap
initLocalRecords []nbdns.SimpleRecord
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap registeredHandlerMap
expectedLocalMap registrationMap
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
dummyHandler.id(): handlerWrapper{
dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
generateDummyHandler(".", nameServers).id(): handlerWrapper{
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
expectedUpstreamMap: registeredHandlerMap{
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) {
priority: PriorityMatchDomain,
},
},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) {
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalMap: make(registrationMap),
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) {
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalMap: make(registrationMap),
expectedLocalQs: []dns.Question{},
},
}
@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) {
}
}
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
for key := range testCase.expectedLocalMap {
_, found := dnsServer.localResolver.registeredMap[key]
if !found {
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &localResolver{},
handler: &local.Resolver{},
priority: PriorityMatchDomain,
},
}
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) {
}
time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop()
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
handlerChain: NewHandlerChain(),
hostManager: hostManager,
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tc.query, dns.TypeA)
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.On("ServeDNS", mock.Anything, r).Once()
@ -1037,9 +1047,9 @@ type mockHandler struct {
}
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) stop() {}
func (m *mockHandler) probeAvailability() {}
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability() {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}
@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
name string
initialHandlers registeredHandlerMap
updates []handlerWrapper
expectedHandlers map[string]string // map[handlerID]domain
expectedHandlers map[string]string // map[HandlerID]domain
description string
}{
{
@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[handlerID(id)]
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}
// Verify no unexpected handlers exist
for handlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(handlerID)]
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
for HandlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(HandlerID)]
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
}
// Verify the handlerChain state and order
@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) {
handlerChain: NewHandlerChain(),
wgInterface: &mocWGIface{},
hostManager: mockHostConfig,
localResolver: &localResolver{},
localResolver: &local.Resolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
localResolver: &local.Resolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
localResolver: &local.Resolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) {
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
localResolver: &local.Resolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),

View File

@ -0,0 +1,26 @@
package test
import (
"net"
"github.com/miekg/dns"
)
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *MockResponseWriter) Close() error { return nil }
func (rw *MockResponseWriter) TsigStatus() error { return nil }
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
func (rw *MockResponseWriter) Hijack() {}

View File

@ -0,0 +1,3 @@
package types
type HandlerID string

View File

@ -19,6 +19,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string {
}
// ID returns the unique handler ID
func (u *upstreamResolverBase) id() handlerID {
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers)
slices.Sort(servers)
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ",")))
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
func (u *upstreamResolverBase) MatchSubdomains() bool {
return true
}
func (u *upstreamResolverBase) stop() {
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
}
@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
)
}
// probeAvailability tests all upstream servers simultaneously and
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) probeAvailability() {
func (u *upstreamResolverBase) ProbeAvailability() {
u.mutex.Lock()
defer u.mutex.Unlock()

View File

@ -8,6 +8,8 @@ import (
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
func TestUpstreamResolver_ServeDNS(t *testing.T) {
@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &mockResponseWriter{
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil },
}

View File

@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
func (s SimpleRecord) Len() uint16 {
emptyString := s.RData == ""
switch s.Type {
case 1:
case int(dns.TypeA):
if emptyString {
return 0
}
return net.IPv4len
case 5:
case int(dns.TypeCNAME):
if emptyString || s.RData == "." {
return 1
}
return uint16(len(s.RData) + 1)
case 28:
case int(dns.TypeAAAA):
if emptyString {
return 0
}