diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 4c910a95f..5f03e0758 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,7 +1,6 @@ package dns_test import ( - "net" "testing" "github.com/miekg/dns" @@ -9,6 +8,7 @@ import ( "github.com/stretchr/testify/mock" nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dns/test" ) // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order @@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { r.SetQuestion("example.com.", dns.TypeA) // Create test writer - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup expectations - only highest priority handler should be called dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() @@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) @@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { // Create and execute request r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) // Verify expectations @@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { }).Once() // Execute - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) // Verify all handlers were called in order @@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { handler3.AssertExpectations(t) } -// mockResponseWriter implements dns.ResponseWriter for testing -type mockResponseWriter struct { - mock.Mock -} - -func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } -func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } -func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } -func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } -func (m *mockResponseWriter) Close() error { return nil } -func (m *mockResponseWriter) TsigStatus() error { return nil } -func (m *mockResponseWriter) TsigTimersOnly(bool) {} -func (m *mockResponseWriter) Hijack() {} - func TestHandlerChain_PriorityDeregistration(t *testing.T) { tests := []struct { name string @@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { // Create test request r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup expectations for priority, handler := range handlers { @@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) // Test 1: Initial state - w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Highest priority handler (routeHandler) should be called routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet @@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 2: Remove highest priority handler chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) - w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now middle priority handler (matchHandler) should be called matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet @@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 3: Remove middle priority handler chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) - w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now lowest priority handler (defaultHandler) should be called defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() @@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 4: Remove last handler chain.RemoveHandler(testDomain, nbdns.PriorityDefault) - w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain for _, m := range mocks { @@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { // Execute request r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - chain.ServeDNS(&mockResponseWriter{}, r) + chain.ServeDNS(&test.MockResponseWriter{}, r) // Verify each handler was called exactly as expected for _, h := range tt.addHandlers { @@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup handler expectations for pattern, handler := range handlers { @@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { handler := &nbdns.MockHandler{} r := new(dns.Msg) r.SetQuestion(tt.queryPattern, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // First verify no handler is called before adding any chain.ServeDNS(w, r) diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go deleted file mode 100644 index 76e18e3ce..000000000 --- a/client/internal/dns/local.go +++ /dev/null @@ -1,130 +0,0 @@ -package dns - -import ( - "fmt" - "strings" - "sync" - - "github.com/miekg/dns" - log "github.com/sirupsen/logrus" - - nbdns "github.com/netbirdio/netbird/dns" -) - -type registrationMap map[string]struct{} - -type localResolver struct { - registeredMap registrationMap - records sync.Map // key: string (domain_class_type), value: []dns.RR -} - -func (d *localResolver) MatchSubdomains() bool { - return true -} - -func (d *localResolver) stop() { -} - -// String returns a string representation of the local resolver -func (d *localResolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) -} - -// ID returns the unique handler ID -func (d *localResolver) id() handlerID { - return "local-resolver" -} - -// ServeDNS handles a DNS request -func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) > 0 { - log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - } - - replyMessage := &dns.Msg{} - replyMessage.SetReply(r) - replyMessage.RecursionAvailable = true - - // lookup all records matching the question - records := d.lookupRecords(r) - if len(records) > 0 { - replyMessage.Rcode = dns.RcodeSuccess - replyMessage.Answer = append(replyMessage.Answer, records...) - } else { - replyMessage.Rcode = dns.RcodeNameError - } - - err := w.WriteMsg(replyMessage) - if err != nil { - log.Debugf("got an error while writing the local resolver response, error: %v", err) - } -} - -// lookupRecords fetches *all* DNS records matching the first question in r. -func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR { - if len(r.Question) == 0 { - return nil - } - question := r.Question[0] - question.Name = strings.ToLower(question.Name) - key := buildRecordKey(question.Name, question.Qclass, question.Qtype) - - value, found := d.records.Load(key) - if !found { - // alternatively check if we have a cname - if question.Qtype != dns.TypeCNAME { - r.Question[0].Qtype = dns.TypeCNAME - return d.lookupRecords(r) - } - - return nil - } - - records, ok := value.([]dns.RR) - if !ok { - log.Errorf("failed to cast records to []dns.RR, records: %v", value) - return nil - } - - // if there's more than one record, rotate them (round-robin) - if len(records) > 1 { - first := records[0] - records = append(records[1:], first) - d.records.Store(key, records) - } - - return records -} - -// registerRecord stores a new record by appending it to any existing list -func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) { - rr, err := dns.NewRR(record.String()) - if err != nil { - return "", fmt.Errorf("register record: %w", err) - } - - rr.Header().Rdlength = record.Len() - header := rr.Header() - key := buildRecordKey(header.Name, header.Class, header.Rrtype) - - // load any existing slice of records, then append - existing, _ := d.records.LoadOrStore(key, []dns.RR{}) - records := existing.([]dns.RR) - records = append(records, rr) - - // store updated slice - d.records.Store(key, records) - return key, nil -} - -// deleteRecord removes *all* records under the recordKey. -func (d *localResolver) deleteRecord(recordKey string) { - d.records.Delete(dns.Fqdn(recordKey)) -} - -// buildRecordKey consistently generates a key: name_class_type -func buildRecordKey(name string, class, qType uint16) string { - return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType) -} - -func (d *localResolver) probeAvailability() {} diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go new file mode 100644 index 000000000..de3d8514b --- /dev/null +++ b/client/internal/dns/local/local.go @@ -0,0 +1,149 @@ +package local + +import ( + "fmt" + "slices" + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/dns/types" + nbdns "github.com/netbirdio/netbird/dns" +) + +type Resolver struct { + mu sync.RWMutex + records map[dns.Question][]dns.RR +} + +func NewResolver() *Resolver { + return &Resolver{ + records: make(map[dns.Question][]dns.RR), + } +} + +func (d *Resolver) MatchSubdomains() bool { + return true +} + +// String returns a string representation of the local resolver +func (d *Resolver) String() string { + return fmt.Sprintf("local resolver [%d records]", len(d.records)) +} + +func (d *Resolver) Stop() {} + +// ID returns the unique handler ID +func (d *Resolver) ID() types.HandlerID { + return "local-resolver" +} + +func (d *Resolver) ProbeAvailability() {} + +// ServeDNS handles a DNS request +func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + log.Debugf("received local resolver request with no question") + return + } + question := r.Question[0] + question.Name = strings.ToLower(dns.Fqdn(question.Name)) + + log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass) + + replyMessage := &dns.Msg{} + replyMessage.SetReply(r) + replyMessage.RecursionAvailable = true + + // lookup all records matching the question + records := d.lookupRecords(question) + if len(records) > 0 { + replyMessage.Rcode = dns.RcodeSuccess + replyMessage.Answer = append(replyMessage.Answer, records...) + } else { + // TODO: return success if we have a different record type for the same name, relevant for search domains + replyMessage.Rcode = dns.RcodeNameError + } + + if err := w.WriteMsg(replyMessage); err != nil { + log.Warnf("failed to write the local resolver response: %v", err) + } +} + +// lookupRecords fetches *all* DNS records matching the first question in r. +func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { + d.mu.RLock() + records, found := d.records[question] + + if !found { + d.mu.RUnlock() + // alternatively check if we have a cname + if question.Qtype != dns.TypeCNAME { + question.Qtype = dns.TypeCNAME + return d.lookupRecords(question) + } + return nil + } + + recordsCopy := slices.Clone(records) + d.mu.RUnlock() + + // if there's more than one record, rotate them (round-robin) + if len(recordsCopy) > 1 { + d.mu.Lock() + records = d.records[question] + if len(records) > 1 { + first := records[0] + records = append(records[1:], first) + d.records[question] = records + } + d.mu.Unlock() + } + + return recordsCopy +} + +func (d *Resolver) Update(update []nbdns.SimpleRecord) { + d.mu.Lock() + defer d.mu.Unlock() + + maps.Clear(d.records) + + for _, rec := range update { + if err := d.registerRecord(rec); err != nil { + log.Warnf("failed to register the record (%s): %v", rec, err) + continue + } + } +} + +// RegisterRecord stores a new record by appending it to any existing list +func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error { + d.mu.Lock() + defer d.mu.Unlock() + + return d.registerRecord(record) +} + +// registerRecord performs the registration with the lock already held +func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error { + rr, err := dns.NewRR(record.String()) + if err != nil { + return fmt.Errorf("register record: %w", err) + } + + rr.Header().Rdlength = record.Len() + header := rr.Header() + q := dns.Question{ + Name: strings.ToLower(dns.Fqdn(header.Name)), + Qtype: header.Rrtype, + Qclass: header.Class, + } + + d.records[q] = append(d.records[q], rr) + + return nil +} diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go new file mode 100644 index 000000000..1d38191e7 --- /dev/null +++ b/client/internal/dns/local/local_test.go @@ -0,0 +1,472 @@ +package local + +import ( + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + nbdns "github.com/netbirdio/netbird/dns" +) + +func TestLocalResolver_ServeDNS(t *testing.T) { + recordA := nbdns.SimpleRecord{ + Name: "peera.netbird.cloud.", + Type: 1, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "1.2.3.4", + } + + recordCNAME := nbdns.SimpleRecord{ + Name: "peerb.netbird.cloud.", + Type: 5, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "www.netbird.io", + } + + testCases := []struct { + name string + inputRecord nbdns.SimpleRecord + inputMSG *dns.Msg + responseShouldBeNil bool + }{ + { + name: "Should Resolve A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA), + }, + { + name: "Should Resolve CNAME Record", + inputRecord: recordCNAME, + inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME), + }, + { + name: "Should Not Write When Not Found A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), + responseShouldBeNil: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + resolver := NewResolver() + _ = resolver.RegisterRecord(testCase.inputRecord) + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, testCase.inputMSG) + + if responseMSG == nil || len(responseMSG.Answer) == 0 { + if testCase.responseShouldBeNil { + return + } + t.Fatalf("should write a response message") + } + + answerString := responseMSG.Answer[0].String() + if !strings.Contains(answerString, testCase.inputRecord.Name) { + t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) + } + if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { + t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString) + } + if !strings.Contains(answerString, testCase.inputRecord.RData) { + t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString) + } + }) + } +} + +// TestLocalResolver_Update_StaleRecord verifies that updating +// a record correctly replaces the old one, preventing stale entries. +func TestLocalResolver_Update_StaleRecord(t *testing.T) { + recordName := "host.example.com." + recordType := dns.TypeA + recordClass := dns.ClassINET + + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2", + } + + recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType} + + resolver := NewResolver() + + update1 := []nbdns.SimpleRecord{record1} + update2 := []nbdns.SimpleRecord{record2} + + // Apply first update + resolver.Update(update1) + + // Verify first update + resolver.mu.RLock() + rrSlice1, found1 := resolver.records[recordKey] + resolver.mu.RUnlock() + + require.True(t, found1, "Record key %s not found after first update", recordKey) + require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update") + assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData) + + // Apply second update + resolver.Update(update2) + + // Verify second update + resolver.mu.RLock() + rrSlice2, found2 := resolver.records[recordKey] + resolver.mu.RUnlock() + + require.True(t, found2, "Record key %s not found after second update", recordKey) + require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key") + assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData) + assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData) +} + +// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records +// with the same question are stored properly +func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) { + resolver := NewResolver() + + recordName := "multi.example.com." + recordType := dns.TypeA + + // Create two records with the same name and type but different IPs + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2", + } + + update := []nbdns.SimpleRecord{record1, record2} + + // Apply update with both records + resolver.Update(update) + + // Create question that matches both records + question := dns.Question{ + Name: recordName, + Qtype: recordType, + Qclass: dns.ClassINET, + } + + // Verify both records are stored + resolver.mu.RLock() + records, found := resolver.records[question] + resolver.mu.RUnlock() + + require.True(t, found, "Records for question %v not found", question) + require.Len(t, records, 2, "Should have exactly 2 records for the same question") + + // Verify both record data values are present + recordStrings := []string{records[0].String(), records[1].String()} + assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present") + assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present") +} + +// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion +func TestLocalResolver_RecordRotation(t *testing.T) { + resolver := NewResolver() + + recordName := "rotation.example.com." + recordType := dns.TypeA + + // Create three records with the same name and type but different IPs + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2", + } + record3 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3", + } + + update := []nbdns.SimpleRecord{record1, record2, record3} + + // Apply update with all three records + resolver.Update(update) + + msg := new(dns.Msg).SetQuestion(recordName, recordType) + + // First lookup - should return the records in original order + var responses [3]*dns.Msg + + // Perform three lookups to verify rotation + for i := 0; i < 3; i++ { + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responses[i] = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, msg) + } + + // Verify all three responses contain answers + for i, resp := range responses { + require.NotNil(t, resp, "Response %d should not be nil", i) + require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i) + } + + // Verify the first record in each response is different due to rotation + firstRecordIPs := []string{ + responses[0].Answer[0].String(), + responses[1].Answer[0].String(), + responses[2].Answer[0].String(), + } + + // Each record should be different (rotated) + assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation") + assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation") + assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation") + + // After three rotations, we should have cycled through all records + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData) + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData) + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData) +} + +// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive +func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) { + resolver := NewResolver() + + // Create record with lowercase name + lowerCaseRecord := nbdns.SimpleRecord{ + Name: "lower.example.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "10.10.10.10", + } + + // Create record with mixed case name + mixedCaseRecord := nbdns.SimpleRecord{ + Name: "MiXeD.ExAmPlE.CoM.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "20.20.20.20", + } + + // Update resolver with the records + resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}) + + testCases := []struct { + name string + queryName string + expectedRData string + shouldResolve bool + }{ + { + name: "Query lowercase with lowercase record", + queryName: "lower.example.com.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query uppercase with lowercase record", + queryName: "LOWER.EXAMPLE.COM.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query mixed case with lowercase record", + queryName: "LoWeR.eXaMpLe.CoM.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query lowercase with mixed case record", + queryName: "mixed.example.com.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query uppercase with mixed case record", + queryName: "MIXED.EXAMPLE.COM.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query with different casing pattern", + queryName: "mIxEd.ExaMpLe.cOm.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query non-existent domain", + queryName: "nonexistent.example.com.", + shouldResolve: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var responseMSG *dns.Msg + + // Create DNS query with the test case name + msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA) + + // Create mock response writer to capture the response + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + // Perform DNS query + resolver.ServeDNS(responseWriter, msg) + + // Check if we expect a successful resolution + if !tc.shouldResolve { + if responseMSG == nil || len(responseMSG.Answer) == 0 { + // Expected no answer, test passes + return + } + t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer) + } + + // Verify we got a response + require.NotNil(t, responseMSG, "Should have received a response message") + require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer") + + // Verify the response contains the expected data + answerString := responseMSG.Answer[0].String() + assert.Contains(t, answerString, tc.expectedRData, + "Answer should contain the expected IP address %s, got: %s", + tc.expectedRData, answerString) + }) + } +} + +// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back +// to checking for CNAME records when the requested record type isn't found +func TestLocalResolver_CNAMEFallback(t *testing.T) { + resolver := NewResolver() + + // Create a CNAME record (but no A record for this name) + cnameRecord := nbdns.SimpleRecord{ + Name: "alias.example.com.", + Type: int(dns.TypeCNAME), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "target.example.com.", + } + + // Create an A record for the CNAME target + targetRecord := nbdns.SimpleRecord{ + Name: "target.example.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.100.100", + } + + // Update resolver with both records + resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}) + + testCases := []struct { + name string + queryName string + queryType uint16 + expectedType string + expectedRData string + shouldResolve bool + }{ + { + name: "Directly query CNAME record", + queryName: "alias.example.com.", + queryType: dns.TypeCNAME, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query A record but get CNAME fallback", + queryName: "alias.example.com.", + queryType: dns.TypeA, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query AAAA record but get CNAME fallback", + queryName: "alias.example.com.", + queryType: dns.TypeAAAA, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query direct A record", + queryName: "target.example.com.", + queryType: dns.TypeA, + expectedType: "A", + expectedRData: "192.168.100.100", + shouldResolve: true, + }, + { + name: "Query non-existent name", + queryName: "nonexistent.example.com.", + queryType: dns.TypeA, + shouldResolve: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var responseMSG *dns.Msg + + // Create DNS query with the test case parameters + msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType) + + // Create mock response writer to capture the response + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + // Perform DNS query + resolver.ServeDNS(responseWriter, msg) + + // Check if we expect a successful resolution + if !tc.shouldResolve { + if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess { + // Expected no resolution, test passes + return + } + t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer) + } + + // Verify we got a successful response + require.NotNil(t, responseMSG, "Should have received a response message") + require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code") + require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer") + + // Verify the response contains the expected data + answerString := responseMSG.Answer[0].String() + assert.Contains(t, answerString, tc.expectedType, + "Answer should be of type %s, got: %s", tc.expectedType, answerString) + assert.Contains(t, answerString, tc.expectedRData, + "Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString) + }) + } +} diff --git a/client/internal/dns/local_test.go b/client/internal/dns/local_test.go deleted file mode 100644 index 0a42b321a..000000000 --- a/client/internal/dns/local_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package dns - -import ( - "strings" - "testing" - - "github.com/miekg/dns" - - nbdns "github.com/netbirdio/netbird/dns" -) - -func TestLocalResolver_ServeDNS(t *testing.T) { - recordA := nbdns.SimpleRecord{ - Name: "peera.netbird.cloud.", - Type: 1, - Class: nbdns.DefaultClass, - TTL: 300, - RData: "1.2.3.4", - } - - recordCNAME := nbdns.SimpleRecord{ - Name: "peerb.netbird.cloud.", - Type: 5, - Class: nbdns.DefaultClass, - TTL: 300, - RData: "www.netbird.io", - } - - testCases := []struct { - name string - inputRecord nbdns.SimpleRecord - inputMSG *dns.Msg - responseShouldBeNil bool - }{ - { - name: "Should Resolve A Record", - inputRecord: recordA, - inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA), - }, - { - name: "Should Resolve CNAME Record", - inputRecord: recordCNAME, - inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME), - }, - { - name: "Should Not Write When Not Found A Record", - inputRecord: recordA, - inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), - responseShouldBeNil: true, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - resolver := &localResolver{ - registeredMap: make(registrationMap), - } - _, _ = resolver.registerRecord(testCase.inputRecord) - var responseMSG *dns.Msg - responseWriter := &mockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - responseMSG = m - return nil - }, - } - - resolver.ServeDNS(responseWriter, testCase.inputMSG) - - if responseMSG == nil || len(responseMSG.Answer) == 0 { - if testCase.responseShouldBeNil { - return - } - t.Fatalf("should write a response message") - } - - answerString := responseMSG.Answer[0].String() - if !strings.Contains(answerString, testCase.inputRecord.Name) { - t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) - } - if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { - t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString) - } - if !strings.Contains(answerString, testCase.inputRecord.RData) { - t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString) - } - }) - } -} diff --git a/client/internal/dns/mock_test.go b/client/internal/dns/mock_test.go deleted file mode 100644 index d52ae24da..000000000 --- a/client/internal/dns/mock_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package dns - -import ( - "net" - - "github.com/miekg/dns" -) - -type mockResponseWriter struct { - WriteMsgFunc func(m *dns.Msg) error -} - -func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error { - if rw.WriteMsgFunc != nil { - return rw.WriteMsgFunc(m) - } - return nil -} - -func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil } -func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil } -func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } -func (rw *mockResponseWriter) Close() error { return nil } -func (rw *mockResponseWriter) TsigStatus() error { return nil } -func (rw *mockResponseWriter) TsigTimersOnly(bool) {} -func (rw *mockResponseWriter) Hijack() {} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 65b90e5f0..3f49c23fd 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -15,6 +15,8 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -46,8 +48,6 @@ type Server interface { ProbeAvailability() } -type handlerID string - type nsGroupsByDomain struct { domain string groups []*nbdns.NameServerGroup @@ -61,7 +61,7 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap - localResolver *localResolver + localResolver *local.Resolver wgInterface WGIface hostManager hostManager updateSerial uint64 @@ -84,9 +84,9 @@ type DefaultServer struct { type handlerWithStop interface { dns.Handler - stop() - probeAvailability() - id() handlerID + Stop() + ProbeAvailability() + ID() types.HandlerID } type handlerWrapper struct { @@ -95,7 +95,7 @@ type handlerWrapper struct { priority int } -type registeredHandlerMap map[handlerID]handlerWrapper +type registeredHandlerMap map[types.HandlerID]handlerWrapper // NewDefaultServer returns a new dns server func NewDefaultServer( @@ -171,16 +171,14 @@ func newDefaultServer( handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), wgInterface: wgInterface, statusRecorder: statusRecorder, stateManager: stateManager, @@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() { wg.Add(1) go func(mux handlerWithStop) { defer wg.Done() - mux.probeAvailability() + mux.ProbeAvailability() }(mux.handler) } wg.Wait() @@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.service.Stop() } - localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones) + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { return fmt.Errorf("local handler updater: %w", err) } @@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.updateMux(muxUpdates) // register local records - s.updateLocalResolver(localRecordsByDomain) + s.localResolver.Update(localRecords) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) @@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) { ) } -func (s *DefaultServer) buildLocalHandlerUpdate( - customZones []nbdns.CustomZone, -) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) { +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { var muxUpdates []handlerWrapper - localRecords := make(map[string][]nbdns.SimpleRecord) + var localRecords []nbdns.SimpleRecord for _, customZone := range customZones { if len(customZone.Records) == 0 { @@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate( priority: PriorityMatchDomain, }) - // group all records under this domain for _, record := range customZone.Records { - var class uint16 = dns.ClassINET if record.Class != nbdns.DefaultClass { log.Warnf("received an invalid class type: %s", record.Class) continue } - - key := buildRecordKey(record.Name, class, uint16(record.Type)) - - localRecords[key] = append(localRecords[key], record) + // zone records contain the fqdn, so we can just flatten them + localRecords = append(localRecords, record) } } @@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai } if len(handler.upstreamServers) == 0 { - handler.stop() + handler.Stop() log.Errorf("received a nameserver group with an invalid nameserver list") continue } @@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { // this will introduce a short period of time when the server is not able to handle DNS requests for _, existing := range s.dnsMuxMap { s.deregisterHandler([]string{existing.domain}, existing.priority) - existing.handler.stop() + existing.handler.Stop() } muxUpdateMap := make(registeredHandlerMap) @@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { containsRootUpdate = true } s.registerHandler([]string{update.domain}, update.handler, update.priority) - muxUpdateMap[update.handler.id()] = update + muxUpdateMap[update.handler.ID()] = update } // If there's no root update and we had a root handler, restore it @@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { s.dnsMuxMap = muxUpdateMap } -func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) { - // remove old records that are no longer present - for key := range s.localResolver.registeredMap { - _, found := update[key] - if !found { - s.localResolver.deleteRecord(key) - } - } - - updatedMap := make(registrationMap) - for _, recs := range update { - for _, rec := range recs { - // convert the record to a dns.RR and register - key, err := s.localResolver.registerRecord(rec) - if err != nil { - log.Warnf("got an error while registering the record (%s), error: %v", - rec.String(), err) - continue - } - - updatedMap[key] = struct{}{} - } - } - - s.localResolver.registeredMap = updatedMap -} - func getNSHostPort(ns nbdns.NameServer) string { return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index ed69b0e93..1c7c9b117 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -23,6 +23,9 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe } func TestUpdateDNSServer(t *testing.T) { + nameServers := []nbdns.NameServer{ { IP: netip.MustParseAddr("8.8.8.8"), @@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) { }, } - dummyHandler := &localResolver{} + dummyHandler := local.NewResolver() testCases := []struct { name string initUpstreamMap registeredHandlerMap - initLocalMap registrationMap + initLocalRecords []nbdns.SimpleRecord initSerial uint64 inputSerial uint64 inputUpdate nbdns.Config shouldFail bool expectedUpstreamMap registeredHandlerMap - expectedLocalMap registrationMap + expectedLocalQs []dns.Question }{ { name: "Initial Config Should Succeed", - initLocalMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, @@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registeredHandlerMap{ - generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, priority: PriorityMatchDomain, }, - dummyHandler.id(): handlerWrapper{ + dummyHandler.ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, priority: PriorityMatchDomain, }, - generateDummyHandler(".", nameServers).id(): handlerWrapper{ + generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: nbdns.RootZone, handler: dummyHandler, priority: PriorityDefault, }, }, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, }, { - name: "New Config Should Succeed", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "New Config Should Succeed", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ - domain: buildRecordKey(zoneRecords[0].Name, 1, 1), + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ + domain: "netbird.cloud", handler: dummyHandler, priority: PriorityMatchDomain, }, @@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registeredHandlerMap{ - generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, priority: PriorityMatchDomain, @@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) { priority: PriorityMatchDomain, }, }, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, }, { - name: "Smaller Config Serial Should Be Skipped", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 2, - inputSerial: 1, - shouldFail: true, + name: "Smaller Config Serial Should Be Skipped", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 2, + inputSerial: 1, + shouldFail: true, }, { - name: "Empty NS Group Domain Or Not Primary Element Should Fail", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Empty NS Group Domain Or Not Primary Element Should Fail", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid NS Group Nameservers list Should Fail", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Invalid NS Group Nameservers list Should Fail", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid Custom Zone Records list Should Skip", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Invalid Custom Zone Records list Should Skip", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{ + expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: ".", handler: dummyHandler, priority: PriorityDefault, }}, }, { - name: "Empty Config Should Succeed and Clean Maps", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "Empty Config Should Succeed and Clean Maps", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, priority: PriorityMatchDomain, @@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) { inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: true}, expectedUpstreamMap: make(registeredHandlerMap), - expectedLocalMap: make(registrationMap), + expectedLocalQs: []dns.Question{}, }, { - name: "Disabled Service Should clean map", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "Disabled Service Should clean map", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, priority: PriorityMatchDomain, @@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) { inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: false}, expectedUpstreamMap: make(registeredHandlerMap), - expectedLocalMap: make(registrationMap), + expectedLocalQs: []dns.Question{}, }, } @@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) { }() dnsServer.dnsMuxMap = testCase.initUpstreamMap - dnsServer.localResolver.registeredMap = testCase.initLocalMap + dnsServer.localResolver.Update(testCase.initLocalRecords) dnsServer.updateSerial = testCase.initSerial err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) @@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) { } } - if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) { - t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap)) + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + for _, q := range testCase.expectedLocalQs { + dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{ + Question: []dns.Question{q}, + }) } - for key := range testCase.expectedLocalMap { - _, found := dnsServer.localResolver.registeredMap[key] - if !found { - t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap) - } + if len(testCase.expectedLocalQs) > 0 { + assert.NotNil(t, responseMSG, "response message should not be nil") + assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success") + assert.NotEmpty(t, responseMSG.Answer, "response message should have answers") } }) } @@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { dnsServer.dnsMuxMap = registeredHandlerMap{ "id1": handlerWrapper{ domain: zoneRecords[0].Name, - handler: &localResolver{}, + handler: &local.Resolver{}, priority: PriorityMatchDomain, }, } - dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} + //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} + dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}) dnsServer.updateSerial = 0 nameServers := []nbdns.NameServer{ @@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) { } time.Sleep(100 * time.Millisecond) defer dnsServer.Stop() - _, err = dnsServer.localResolver.registerRecord(zoneRecords[0]) + err = dnsServer.localResolver.RegisterRecord(zoneRecords[0]) if err != nil { t.Error(err) } @@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ - ctx: context.Background(), - service: NewServiceViaMemory(&mocWGIface{}), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, - handlerChain: NewHandlerChain(), - hostManager: hostManager, + ctx: context.Background(), + service: NewServiceViaMemory(&mocWGIface{}), + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { t.Run(tc.name, func(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tc.query, dns.TypeA) - w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} if mh, ok := tc.expectedHandler.(*MockHandler); ok { mh.On("ServeDNS", mock.Anything, r).Once() @@ -1037,9 +1047,9 @@ type mockHandler struct { } func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} -func (m *mockHandler) stop() {} -func (m *mockHandler) probeAvailability() {} -func (m *mockHandler) id() handlerID { return handlerID(m.Id) } +func (m *mockHandler) Stop() {} +func (m *mockHandler) ProbeAvailability() {} +func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} @@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { name string initialHandlers registeredHandlerMap updates []handlerWrapper - expectedHandlers map[string]string // map[handlerID]domain + expectedHandlers map[string]string // map[HandlerID]domain description string }{ { @@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { // Check each expected handler for id, expectedDomain := range tt.expectedHandlers { - handler, exists := server.dnsMuxMap[handlerID(id)] + handler, exists := server.dnsMuxMap[types.HandlerID(id)] assert.True(t, exists, "Expected handler %s not found", id) if exists { assert.Equal(t, expectedDomain, handler.domain, @@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) { } // Verify no unexpected handlers exist - for handlerID := range server.dnsMuxMap { - _, expected := tt.expectedHandlers[string(handlerID)] - assert.True(t, expected, "Unexpected handler found: %s", handlerID) + for HandlerID := range server.dnsMuxMap { + _, expected := tt.expectedHandlers[string(HandlerID)] + assert.True(t, expected, "Unexpected handler found: %s", HandlerID) } // Verify the handlerChain state and order @@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) { handlerChain: NewHandlerChain(), wgInterface: &mocWGIface{}, hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), diff --git a/client/internal/dns/test/mock.go b/client/internal/dns/test/mock.go new file mode 100644 index 000000000..1db452805 --- /dev/null +++ b/client/internal/dns/test/mock.go @@ -0,0 +1,26 @@ +package test + +import ( + "net" + + "github.com/miekg/dns" +) + +type MockResponseWriter struct { + WriteMsgFunc func(m *dns.Msg) error +} + +func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { + if rw.WriteMsgFunc != nil { + return rw.WriteMsgFunc(m) + } + return nil +} + +func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } +func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } +func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (rw *MockResponseWriter) Close() error { return nil } +func (rw *MockResponseWriter) TsigStatus() error { return nil } +func (rw *MockResponseWriter) TsigTimersOnly(bool) {} +func (rw *MockResponseWriter) Hijack() {} diff --git a/client/internal/dns/types/types.go b/client/internal/dns/types/types.go new file mode 100644 index 000000000..5a8be03b7 --- /dev/null +++ b/client/internal/dns/types/types.go @@ -0,0 +1,3 @@ +package types + +type HandlerID string diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index fa69d4934..2fbfb3b91 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -19,6 +19,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" ) @@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string { } // ID returns the unique handler ID -func (u *upstreamResolverBase) id() handlerID { +func (u *upstreamResolverBase) ID() types.HandlerID { servers := slices.Clone(u.upstreamServers) slices.Sort(servers) hash := sha256.New() hash.Write([]byte(u.domain + ":")) hash.Write([]byte(strings.Join(servers, ","))) - return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) + return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) } func (u *upstreamResolverBase) MatchSubdomains() bool { return true } -func (u *upstreamResolverBase) stop() { +func (u *upstreamResolverBase) Stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() } @@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) { ) } -// probeAvailability tests all upstream servers simultaneously and +// ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work -func (u *upstreamResolverBase) probeAvailability() { +func (u *upstreamResolverBase) ProbeAvailability() { u.mutex.Lock() defer u.mutex.Unlock() diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 5dbcc9f79..13bc91a37 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -8,6 +8,8 @@ import ( "time" "github.com/miekg/dns" + + "github.com/netbirdio/netbird/client/internal/dns/test" ) func TestUpstreamResolver_ServeDNS(t *testing.T) { @@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { } var responseMSG *dns.Msg - responseWriter := &mockResponseWriter{ + responseWriter := &test.MockResponseWriter{ WriteMsgFunc: func(m *dns.Msg) error { responseMSG = m return nil @@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { resolver.failsTillDeact = 0 resolver.reactivatePeriod = time.Microsecond * 100 - responseWriter := &mockResponseWriter{ + responseWriter := &test.MockResponseWriter{ WriteMsgFunc: func(m *dns.Msg) error { return nil }, } diff --git a/dns/dns.go b/dns/dns.go index 3a1c76e56..f889a32ec 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -66,17 +66,17 @@ func (s SimpleRecord) String() string { func (s SimpleRecord) Len() uint16 { emptyString := s.RData == "" switch s.Type { - case 1: + case int(dns.TypeA): if emptyString { return 0 } return net.IPv4len - case 5: + case int(dns.TypeCNAME): if emptyString || s.RData == "." { return 1 } return uint16(len(s.RData) + 1) - case 28: + case int(dns.TypeAAAA): if emptyString { return 0 }