mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 01:58:16 +02:00
[client] Fix stale local records (#3776)
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@ -9,6 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
)
|
||||
|
||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||
@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Create test writer
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// Setup expectations - only highest priority handler should be called
|
||||
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||
@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
// Create and execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify expectations
|
||||
@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
}).Once()
|
||||
|
||||
// Execute
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify all handlers were called in order
|
||||
@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
handler3.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// mockResponseWriter implements dns.ResponseWriter for testing
|
||||
type mockResponseWriter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (m *mockResponseWriter) Close() error { return nil }
|
||||
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (m *mockResponseWriter) Hijack() {}
|
||||
|
||||
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// Setup expectations
|
||||
for priority, handler := range handlers {
|
||||
@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
// Highest priority handler (routeHandler) should be called
|
||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
// Test 2: Remove highest priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||
|
||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
// Now middle priority handler (matchHandler) should be called
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||
@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
// Test 3: Remove middle priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
// Now lowest priority handler (defaultHandler) should be called
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||
|
||||
for _, m := range mocks {
|
||||
@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
// Execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
||||
chain.ServeDNS(&test.MockResponseWriter{}, r)
|
||||
|
||||
// Verify each handler was called exactly as expected
|
||||
for _, h := range tt.addHandlers {
|
||||
@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// Setup handler expectations
|
||||
for pattern, handler := range handlers {
|
||||
@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
handler := &nbdns.MockHandler{}
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
// First verify no handler is called before adding any
|
||||
chain.ServeDNS(w, r)
|
||||
|
@ -1,130 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map // key: string (domain_class_type), value: []dns.RR
|
||||
}
|
||||
|
||||
func (d *localResolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *localResolver) stop() {
|
||||
}
|
||||
|
||||
// String returns a string representation of the local resolver
|
||||
func (d *localResolver) String() string {
|
||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *localResolver) id() handlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) > 0 {
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(r)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
err := w.WriteMsg(replyMessage)
|
||||
if err != nil {
|
||||
log.Debugf("got an error while writing the local resolver response, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
||||
if len(r.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(question.Name)
|
||||
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
|
||||
|
||||
value, found := d.records.Load(key)
|
||||
if !found {
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
r.Question[0].Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
records, ok := value.([]dns.RR)
|
||||
if !ok {
|
||||
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records.Store(key, records)
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// registerRecord stores a new record by appending it to any existing list
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
|
||||
rr, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
rr.Header().Rdlength = record.Len()
|
||||
header := rr.Header()
|
||||
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
|
||||
|
||||
// load any existing slice of records, then append
|
||||
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
|
||||
records := existing.([]dns.RR)
|
||||
records = append(records, rr)
|
||||
|
||||
// store updated slice
|
||||
d.records.Store(key, records)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// deleteRecord removes *all* records under the recordKey.
|
||||
func (d *localResolver) deleteRecord(recordKey string) {
|
||||
d.records.Delete(dns.Fqdn(recordKey))
|
||||
}
|
||||
|
||||
// buildRecordKey consistently generates a key: name_class_type
|
||||
func buildRecordKey(name string, class, qType uint16) string {
|
||||
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
|
||||
}
|
||||
|
||||
func (d *localResolver) probeAvailability() {}
|
149
client/internal/dns/local/local.go
Normal file
149
client/internal/dns/local/local.go
Normal file
@ -0,0 +1,149 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
}
|
||||
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Resolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// String returns a string representation of the local resolver
|
||||
func (d *Resolver) String() string {
|
||||
return fmt.Sprintf("local resolver [%d records]", len(d.records))
|
||||
}
|
||||
|
||||
func (d *Resolver) Stop() {}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability() {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
log.Debugf("received local resolver request with no question")
|
||||
return
|
||||
}
|
||||
question := r.Question[0]
|
||||
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
// lookup all records matching the question
|
||||
records := d.lookupRecords(question)
|
||||
if len(records) > 0 {
|
||||
replyMessage.Rcode = dns.RcodeSuccess
|
||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||
} else {
|
||||
// TODO: return success if we have a different record type for the same name, relevant for search domains
|
||||
replyMessage.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(replyMessage); err != nil {
|
||||
log.Warnf("failed to write the local resolver response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||
d.mu.RLock()
|
||||
records, found := d.records[question]
|
||||
|
||||
if !found {
|
||||
d.mu.RUnlock()
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
question.Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(question)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
recordsCopy := slices.Clone(records)
|
||||
d.mu.RUnlock()
|
||||
|
||||
// if there's more than one record, rotate them (round-robin)
|
||||
if len(recordsCopy) > 1 {
|
||||
d.mu.Lock()
|
||||
records = d.records[question]
|
||||
if len(records) > 1 {
|
||||
first := records[0]
|
||||
records = append(records[1:], first)
|
||||
d.records[question] = records
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
return recordsCopy
|
||||
}
|
||||
|
||||
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
maps.Clear(d.records)
|
||||
|
||||
for _, rec := range update {
|
||||
if err := d.registerRecord(rec); err != nil {
|
||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRecord stores a new record by appending it to any existing list
|
||||
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
return d.registerRecord(record)
|
||||
}
|
||||
|
||||
// registerRecord performs the registration with the lock already held
|
||||
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
rr, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("register record: %w", err)
|
||||
}
|
||||
|
||||
rr.Header().Rdlength = record.Len()
|
||||
header := rr.Header()
|
||||
q := dns.Question{
|
||||
Name: strings.ToLower(dns.Fqdn(header.Name)),
|
||||
Qtype: header.Rrtype,
|
||||
Qclass: header.Class,
|
||||
}
|
||||
|
||||
d.records[q] = append(d.records[q], rr)
|
||||
|
||||
return nil
|
||||
}
|
472
client/internal/dns/local/local_test.go
Normal file
472
client/internal/dns/local/local_test.go
Normal file
@ -0,0 +1,472 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
}
|
||||
|
||||
recordCNAME := nbdns.SimpleRecord{
|
||||
Name: "peerb.netbird.cloud.",
|
||||
Type: 5,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "www.netbird.io",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputRecord nbdns.SimpleRecord
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||
},
|
||||
{
|
||||
name: "Should Resolve CNAME Record",
|
||||
inputRecord: recordCNAME,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||
},
|
||||
{
|
||||
name: "Should Not Write When Not Found A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
_ = resolver.RegisterRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_Update_StaleRecord verifies that updating
|
||||
// a record correctly replaces the old one, preventing stale entries.
|
||||
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||
recordName := "host.example.com."
|
||||
recordType := dns.TypeA
|
||||
recordClass := dns.ClassINET
|
||||
|
||||
record1 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
|
||||
}
|
||||
record2 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
|
||||
}
|
||||
|
||||
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
|
||||
|
||||
resolver := NewResolver()
|
||||
|
||||
update1 := []nbdns.SimpleRecord{record1}
|
||||
update2 := []nbdns.SimpleRecord{record2}
|
||||
|
||||
// Apply first update
|
||||
resolver.Update(update1)
|
||||
|
||||
// Verify first update
|
||||
resolver.mu.RLock()
|
||||
rrSlice1, found1 := resolver.records[recordKey]
|
||||
resolver.mu.RUnlock()
|
||||
|
||||
require.True(t, found1, "Record key %s not found after first update", recordKey)
|
||||
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
|
||||
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||
|
||||
// Apply second update
|
||||
resolver.Update(update2)
|
||||
|
||||
// Verify second update
|
||||
resolver.mu.RLock()
|
||||
rrSlice2, found2 := resolver.records[recordKey]
|
||||
resolver.mu.RUnlock()
|
||||
|
||||
require.True(t, found2, "Record key %s not found after second update", recordKey)
|
||||
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
|
||||
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
|
||||
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
|
||||
}
|
||||
|
||||
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
|
||||
// with the same question are stored properly
|
||||
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
recordName := "multi.example.com."
|
||||
recordType := dns.TypeA
|
||||
|
||||
// Create two records with the same name and type but different IPs
|
||||
record1 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
|
||||
}
|
||||
record2 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
|
||||
}
|
||||
|
||||
update := []nbdns.SimpleRecord{record1, record2}
|
||||
|
||||
// Apply update with both records
|
||||
resolver.Update(update)
|
||||
|
||||
// Create question that matches both records
|
||||
question := dns.Question{
|
||||
Name: recordName,
|
||||
Qtype: recordType,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
|
||||
// Verify both records are stored
|
||||
resolver.mu.RLock()
|
||||
records, found := resolver.records[question]
|
||||
resolver.mu.RUnlock()
|
||||
|
||||
require.True(t, found, "Records for question %v not found", question)
|
||||
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
|
||||
|
||||
// Verify both record data values are present
|
||||
recordStrings := []string{records[0].String(), records[1].String()}
|
||||
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
|
||||
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
|
||||
}
|
||||
|
||||
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
|
||||
func TestLocalResolver_RecordRotation(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
recordName := "rotation.example.com."
|
||||
recordType := dns.TypeA
|
||||
|
||||
// Create three records with the same name and type but different IPs
|
||||
record1 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
|
||||
}
|
||||
record2 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
|
||||
}
|
||||
record3 := nbdns.SimpleRecord{
|
||||
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
|
||||
}
|
||||
|
||||
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||
|
||||
// Apply update with all three records
|
||||
resolver.Update(update)
|
||||
|
||||
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||
|
||||
// First lookup - should return the records in original order
|
||||
var responses [3]*dns.Msg
|
||||
|
||||
// Perform three lookups to verify rotation
|
||||
for i := 0; i < 3; i++ {
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responses[i] = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
}
|
||||
|
||||
// Verify all three responses contain answers
|
||||
for i, resp := range responses {
|
||||
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
|
||||
}
|
||||
|
||||
// Verify the first record in each response is different due to rotation
|
||||
firstRecordIPs := []string{
|
||||
responses[0].Answer[0].String(),
|
||||
responses[1].Answer[0].String(),
|
||||
responses[2].Answer[0].String(),
|
||||
}
|
||||
|
||||
// Each record should be different (rotated)
|
||||
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
|
||||
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
|
||||
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
|
||||
|
||||
// After three rotations, we should have cycled through all records
|
||||
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
|
||||
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
|
||||
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
|
||||
}
|
||||
|
||||
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
|
||||
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Create record with lowercase name
|
||||
lowerCaseRecord := nbdns.SimpleRecord{
|
||||
Name: "lower.example.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "10.10.10.10",
|
||||
}
|
||||
|
||||
// Create record with mixed case name
|
||||
mixedCaseRecord := nbdns.SimpleRecord{
|
||||
Name: "MiXeD.ExAmPlE.CoM.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "20.20.20.20",
|
||||
}
|
||||
|
||||
// Update resolver with the records
|
||||
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
queryName string
|
||||
expectedRData string
|
||||
shouldResolve bool
|
||||
}{
|
||||
{
|
||||
name: "Query lowercase with lowercase record",
|
||||
queryName: "lower.example.com.",
|
||||
expectedRData: "10.10.10.10",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query uppercase with lowercase record",
|
||||
queryName: "LOWER.EXAMPLE.COM.",
|
||||
expectedRData: "10.10.10.10",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query mixed case with lowercase record",
|
||||
queryName: "LoWeR.eXaMpLe.CoM.",
|
||||
expectedRData: "10.10.10.10",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query lowercase with mixed case record",
|
||||
queryName: "mixed.example.com.",
|
||||
expectedRData: "20.20.20.20",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query uppercase with mixed case record",
|
||||
queryName: "MIXED.EXAMPLE.COM.",
|
||||
expectedRData: "20.20.20.20",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query with different casing pattern",
|
||||
queryName: "mIxEd.ExaMpLe.cOm.",
|
||||
expectedRData: "20.20.20.20",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query non-existent domain",
|
||||
queryName: "nonexistent.example.com.",
|
||||
shouldResolve: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var responseMSG *dns.Msg
|
||||
|
||||
// Create DNS query with the test case name
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
|
||||
|
||||
// Create mock response writer to capture the response
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Perform DNS query
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
// Check if we expect a successful resolution
|
||||
if !tc.shouldResolve {
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||
// Expected no answer, test passes
|
||||
return
|
||||
}
|
||||
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||
}
|
||||
|
||||
// Verify we got a response
|
||||
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||
|
||||
// Verify the response contains the expected data
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
assert.Contains(t, answerString, tc.expectedRData,
|
||||
"Answer should contain the expected IP address %s, got: %s",
|
||||
tc.expectedRData, answerString)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
|
||||
// to checking for CNAME records when the requested record type isn't found
|
||||
func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Create a CNAME record (but no A record for this name)
|
||||
cnameRecord := nbdns.SimpleRecord{
|
||||
Name: "alias.example.com.",
|
||||
Type: int(dns.TypeCNAME),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "target.example.com.",
|
||||
}
|
||||
|
||||
// Create an A record for the CNAME target
|
||||
targetRecord := nbdns.SimpleRecord{
|
||||
Name: "target.example.com.",
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "192.168.100.100",
|
||||
}
|
||||
|
||||
// Update resolver with both records
|
||||
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
queryName string
|
||||
queryType uint16
|
||||
expectedType string
|
||||
expectedRData string
|
||||
shouldResolve bool
|
||||
}{
|
||||
{
|
||||
name: "Directly query CNAME record",
|
||||
queryName: "alias.example.com.",
|
||||
queryType: dns.TypeCNAME,
|
||||
expectedType: "CNAME",
|
||||
expectedRData: "target.example.com.",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query A record but get CNAME fallback",
|
||||
queryName: "alias.example.com.",
|
||||
queryType: dns.TypeA,
|
||||
expectedType: "CNAME",
|
||||
expectedRData: "target.example.com.",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query AAAA record but get CNAME fallback",
|
||||
queryName: "alias.example.com.",
|
||||
queryType: dns.TypeAAAA,
|
||||
expectedType: "CNAME",
|
||||
expectedRData: "target.example.com.",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query direct A record",
|
||||
queryName: "target.example.com.",
|
||||
queryType: dns.TypeA,
|
||||
expectedType: "A",
|
||||
expectedRData: "192.168.100.100",
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: "Query non-existent name",
|
||||
queryName: "nonexistent.example.com.",
|
||||
queryType: dns.TypeA,
|
||||
shouldResolve: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var responseMSG *dns.Msg
|
||||
|
||||
// Create DNS query with the test case parameters
|
||||
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
|
||||
|
||||
// Create mock response writer to capture the response
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Perform DNS query
|
||||
resolver.ServeDNS(responseWriter, msg)
|
||||
|
||||
// Check if we expect a successful resolution
|
||||
if !tc.shouldResolve {
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
|
||||
// Expected no resolution, test passes
|
||||
return
|
||||
}
|
||||
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||
}
|
||||
|
||||
// Verify we got a successful response
|
||||
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
|
||||
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||
|
||||
// Verify the response contains the expected data
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
assert.Contains(t, answerString, tc.expectedType,
|
||||
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
|
||||
assert.Contains(t, answerString, tc.expectedRData,
|
||||
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
|
||||
})
|
||||
}
|
||||
}
|
@ -1,88 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
}
|
||||
|
||||
recordCNAME := nbdns.SimpleRecord{
|
||||
Name: "peerb.netbird.cloud.",
|
||||
Type: 5,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "www.netbird.io",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputRecord nbdns.SimpleRecord
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||
},
|
||||
{
|
||||
name: "Should Resolve CNAME Record",
|
||||
inputRecord: recordCNAME,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||
},
|
||||
{
|
||||
name: "Should Not Write When Not Found A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
resolver := &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
}
|
||||
_, _ = resolver.registerRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type mockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (rw *mockResponseWriter) Close() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (rw *mockResponseWriter) Hijack() {}
|
@ -15,6 +15,8 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@ -46,8 +48,6 @@ type Server interface {
|
||||
ProbeAvailability()
|
||||
}
|
||||
|
||||
type handlerID string
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
domain string
|
||||
groups []*nbdns.NameServerGroup
|
||||
@ -61,7 +61,7 @@ type DefaultServer struct {
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
localResolver *localResolver
|
||||
localResolver *local.Resolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
@ -84,9 +84,9 @@ type DefaultServer struct {
|
||||
|
||||
type handlerWithStop interface {
|
||||
dns.Handler
|
||||
stop()
|
||||
probeAvailability()
|
||||
id() handlerID
|
||||
Stop()
|
||||
ProbeAvailability()
|
||||
ID() types.HandlerID
|
||||
}
|
||||
|
||||
type handlerWrapper struct {
|
||||
@ -95,7 +95,7 @@ type handlerWrapper struct {
|
||||
priority int
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[handlerID]handlerWrapper
|
||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(
|
||||
@ -171,16 +171,14 @@ func newDefaultServer(
|
||||
handlerChain := NewHandlerChain()
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: local.NewResolver(),
|
||||
wgInterface: wgInterface,
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
wg.Add(1)
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
mux.probeAvailability()
|
||||
mux.ProbeAvailability()
|
||||
}(mux.handler)
|
||||
}
|
||||
wg.Wait()
|
||||
@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
s.updateMux(muxUpdates)
|
||||
|
||||
// register local records
|
||||
s.updateLocalResolver(localRecordsByDomain)
|
||||
s.localResolver.Update(localRecords)
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||
)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
customZones []nbdns.CustomZone,
|
||||
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
localRecords := make(map[string][]nbdns.SimpleRecord)
|
||||
var localRecords []nbdns.SimpleRecord
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
|
||||
// group all records under this domain
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
|
||||
localRecords[key] = append(localRecords[key], record)
|
||||
// zone records contain the fqdn, so we can just flatten them
|
||||
localRecords = append(localRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
handler.stop()
|
||||
handler.Stop()
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
existing.handler.stop()
|
||||
existing.handler.Stop()
|
||||
}
|
||||
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
containsRootUpdate = true
|
||||
}
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.id()] = update
|
||||
muxUpdateMap[update.handler.ID()] = update
|
||||
}
|
||||
|
||||
// If there's no root update and we had a root handler, restore it
|
||||
@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
|
||||
// remove old records that are no longer present
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for _, recs := range update {
|
||||
for _, rec := range recs {
|
||||
// convert the record to a dns.RR and register
|
||||
key, err := s.localResolver.registerRecord(rec)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v",
|
||||
rec.String(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
@ -23,6 +23,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dummyHandler := &localResolver{}
|
||||
dummyHandler := local.NewResolver()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registeredHandlerMap
|
||||
initLocalMap registrationMap
|
||||
initLocalRecords []nbdns.SimpleRecord
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Config
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registeredHandlerMap
|
||||
expectedLocalMap registrationMap
|
||||
expectedLocalQs []dns.Question
|
||||
}{
|
||||
{
|
||||
name: "Initial Config Should Succeed",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
dummyHandler.id(): handlerWrapper{
|
||||
dummyHandler.ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
name: "New Config Should Succeed",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalRecords: []nbdns.SimpleRecord{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalMap: make(registrationMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
expectedUpstreamMap: make(registeredHandlerMap),
|
||||
expectedLocalMap: make(registrationMap),
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
}
|
||||
|
||||
@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
|
||||
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
for _, q := range testCase.expectedLocalQs {
|
||||
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||
Question: []dns.Question{q},
|
||||
})
|
||||
}
|
||||
|
||||
for key := range testCase.expectedLocalMap {
|
||||
_, found := dnsServer.localResolver.registeredMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
|
||||
}
|
||||
if len(testCase.expectedLocalQs) > 0 {
|
||||
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &localResolver{},
|
||||
handler: &local.Resolver{},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer dnsServer.Stop()
|
||||
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tc.query, dns.TypeA)
|
||||
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||
|
||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||
@ -1037,9 +1047,9 @@ type mockHandler struct {
|
||||
}
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) stop() {}
|
||||
func (m *mockHandler) probeAvailability() {}
|
||||
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability() {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
|
||||
@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
name string
|
||||
initialHandlers registeredHandlerMap
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[handlerID]domain
|
||||
expectedHandlers map[string]string // map[HandlerID]domain
|
||||
description string
|
||||
}{
|
||||
{
|
||||
@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
handler, exists := server.dnsMuxMap[handlerID(id)]
|
||||
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for handlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(handlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
||||
for HandlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) {
|
||||
handlerChain: NewHandlerChain(),
|
||||
wgInterface: &mocWGIface{},
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
localResolver: &local.Resolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
|
26
client/internal/dns/test/mock.go
Normal file
26
client/internal/dns/test/mock.go
Normal file
@ -0,0 +1,26 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type MockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (rw *MockResponseWriter) Close() error { return nil }
|
||||
func (rw *MockResponseWriter) TsigStatus() error { return nil }
|
||||
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (rw *MockResponseWriter) Hijack() {}
|
3
client/internal/dns/types/types.go
Normal file
3
client/internal/dns/types/types.go
Normal file
@ -0,0 +1,3 @@
|
||||
package types
|
||||
|
||||
type HandlerID string
|
@ -19,6 +19,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string {
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (u *upstreamResolverBase) id() handlerID {
|
||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||
servers := slices.Clone(u.upstreamServers)
|
||||
slices.Sort(servers)
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(u.domain + ":"))
|
||||
hash.Write([]byte(strings.Join(servers, ",")))
|
||||
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) stop() {
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
}
|
||||
@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
||||
)
|
||||
}
|
||||
|
||||
// probeAvailability tests all upstream servers simultaneously and
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) probeAvailability() {
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
)
|
||||
|
||||
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
resolver.failsTillDeact = 0
|
||||
resolver.reactivatePeriod = time.Microsecond * 100
|
||||
|
||||
responseWriter := &mockResponseWriter{
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||
}
|
||||
|
||||
|
@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
|
||||
func (s SimpleRecord) Len() uint16 {
|
||||
emptyString := s.RData == ""
|
||||
switch s.Type {
|
||||
case 1:
|
||||
case int(dns.TypeA):
|
||||
if emptyString {
|
||||
return 0
|
||||
}
|
||||
return net.IPv4len
|
||||
case 5:
|
||||
case int(dns.TypeCNAME):
|
||||
if emptyString || s.RData == "." {
|
||||
return 1
|
||||
}
|
||||
return uint16(len(s.RData) + 1)
|
||||
case 28:
|
||||
case int(dns.TypeAAAA):
|
||||
if emptyString {
|
||||
return 0
|
||||
}
|
||||
|
Reference in New Issue
Block a user