mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 18:41:41 +02:00
[client] Fix stale local records (#3776)
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -9,6 +8,7 @@ import (
|
|||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||||
@@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
r.SetQuestion("example.com.", dns.TypeA)
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
// Create test writer
|
// Create test writer
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup expectations - only highest priority handler should be called
|
// Setup expectations - only highest priority handler should be called
|
||||||
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
@@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
@@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
// Create and execute request
|
// Create and execute request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
// Verify expectations
|
// Verify expectations
|
||||||
@@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
}).Once()
|
}).Once()
|
||||||
|
|
||||||
// Execute
|
// Execute
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
// Verify all handlers were called in order
|
// Verify all handlers were called in order
|
||||||
@@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
handler3.AssertExpectations(t)
|
handler3.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockResponseWriter implements dns.ResponseWriter for testing
|
|
||||||
type mockResponseWriter struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
|
||||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
|
||||||
func (m *mockResponseWriter) Close() error { return nil }
|
|
||||||
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
|
||||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
|
||||||
func (m *mockResponseWriter) Hijack() {}
|
|
||||||
|
|
||||||
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup expectations
|
// Setup expectations
|
||||||
for priority, handler := range handlers {
|
for priority, handler := range handlers {
|
||||||
@@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state
|
// Test 1: Initial state
|
||||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Highest priority handler (routeHandler) should be called
|
// Highest priority handler (routeHandler) should be called
|
||||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
@@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Test 2: Remove highest priority handler
|
// Test 2: Remove highest priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now middle priority handler (matchHandler) should be called
|
// Now middle priority handler (matchHandler) should be called
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||||
@@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
@@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||||
|
|
||||||
for _, m := range mocks {
|
for _, m := range mocks {
|
||||||
@@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
// Execute request
|
// Execute request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
chain.ServeDNS(&test.MockResponseWriter{}, r)
|
||||||
|
|
||||||
// Verify each handler was called exactly as expected
|
// Verify each handler was called exactly as expected
|
||||||
for _, h := range tt.addHandlers {
|
for _, h := range tt.addHandlers {
|
||||||
@@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup handler expectations
|
// Setup handler expectations
|
||||||
for pattern, handler := range handlers {
|
for pattern, handler := range handlers {
|
||||||
@@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
handler := &nbdns.MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// First verify no handler is called before adding any
|
// First verify no handler is called before adding any
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
@@ -1,130 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type registrationMap map[string]struct{}
|
|
||||||
|
|
||||||
type localResolver struct {
|
|
||||||
registeredMap registrationMap
|
|
||||||
records sync.Map // key: string (domain_class_type), value: []dns.RR
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *localResolver) MatchSubdomains() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *localResolver) stop() {
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a string representation of the local resolver
|
|
||||||
func (d *localResolver) String() string {
|
|
||||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
|
||||||
func (d *localResolver) id() handlerID {
|
|
||||||
return "local-resolver"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
if len(r.Question) > 0 {
|
|
||||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
|
||||||
}
|
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
|
||||||
replyMessage.SetReply(r)
|
|
||||||
replyMessage.RecursionAvailable = true
|
|
||||||
|
|
||||||
// lookup all records matching the question
|
|
||||||
records := d.lookupRecords(r)
|
|
||||||
if len(records) > 0 {
|
|
||||||
replyMessage.Rcode = dns.RcodeSuccess
|
|
||||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
|
||||||
} else {
|
|
||||||
replyMessage.Rcode = dns.RcodeNameError
|
|
||||||
}
|
|
||||||
|
|
||||||
err := w.WriteMsg(replyMessage)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("got an error while writing the local resolver response, error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
|
||||||
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
|
||||||
if len(r.Question) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
question := r.Question[0]
|
|
||||||
question.Name = strings.ToLower(question.Name)
|
|
||||||
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
|
|
||||||
|
|
||||||
value, found := d.records.Load(key)
|
|
||||||
if !found {
|
|
||||||
// alternatively check if we have a cname
|
|
||||||
if question.Qtype != dns.TypeCNAME {
|
|
||||||
r.Question[0].Qtype = dns.TypeCNAME
|
|
||||||
return d.lookupRecords(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
records, ok := value.([]dns.RR)
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if there's more than one record, rotate them (round-robin)
|
|
||||||
if len(records) > 1 {
|
|
||||||
first := records[0]
|
|
||||||
records = append(records[1:], first)
|
|
||||||
d.records.Store(key, records)
|
|
||||||
}
|
|
||||||
|
|
||||||
return records
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerRecord stores a new record by appending it to any existing list
|
|
||||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
|
|
||||||
rr, err := dns.NewRR(record.String())
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("register record: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rr.Header().Rdlength = record.Len()
|
|
||||||
header := rr.Header()
|
|
||||||
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
|
|
||||||
|
|
||||||
// load any existing slice of records, then append
|
|
||||||
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
|
|
||||||
records := existing.([]dns.RR)
|
|
||||||
records = append(records, rr)
|
|
||||||
|
|
||||||
// store updated slice
|
|
||||||
d.records.Store(key, records)
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteRecord removes *all* records under the recordKey.
|
|
||||||
func (d *localResolver) deleteRecord(recordKey string) {
|
|
||||||
d.records.Delete(dns.Fqdn(recordKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildRecordKey consistently generates a key: name_class_type
|
|
||||||
func buildRecordKey(name string, class, qType uint16) string {
|
|
||||||
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *localResolver) probeAvailability() {}
|
|
149
client/internal/dns/local/local.go
Normal file
149
client/internal/dns/local/local.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package local
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Resolver struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
records map[dns.Question][]dns.RR
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewResolver() *Resolver {
|
||||||
|
return &Resolver{
|
||||||
|
records: make(map[dns.Question][]dns.RR),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the local resolver
|
||||||
|
func (d *Resolver) String() string {
|
||||||
|
return fmt.Sprintf("local resolver [%d records]", len(d.records))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) Stop() {}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (d *Resolver) ID() types.HandlerID {
|
||||||
|
return "local-resolver"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) ProbeAvailability() {}
|
||||||
|
|
||||||
|
// ServeDNS handles a DNS request
|
||||||
|
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
log.Debugf("received local resolver request with no question")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
question := r.Question[0]
|
||||||
|
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||||
|
|
||||||
|
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||||
|
|
||||||
|
replyMessage := &dns.Msg{}
|
||||||
|
replyMessage.SetReply(r)
|
||||||
|
replyMessage.RecursionAvailable = true
|
||||||
|
|
||||||
|
// lookup all records matching the question
|
||||||
|
records := d.lookupRecords(question)
|
||||||
|
if len(records) > 0 {
|
||||||
|
replyMessage.Rcode = dns.RcodeSuccess
|
||||||
|
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||||
|
} else {
|
||||||
|
// TODO: return success if we have a different record type for the same name, relevant for search domains
|
||||||
|
replyMessage.Rcode = dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(replyMessage); err != nil {
|
||||||
|
log.Warnf("failed to write the local resolver response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||||
|
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||||
|
d.mu.RLock()
|
||||||
|
records, found := d.records[question]
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
d.mu.RUnlock()
|
||||||
|
// alternatively check if we have a cname
|
||||||
|
if question.Qtype != dns.TypeCNAME {
|
||||||
|
question.Qtype = dns.TypeCNAME
|
||||||
|
return d.lookupRecords(question)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
recordsCopy := slices.Clone(records)
|
||||||
|
d.mu.RUnlock()
|
||||||
|
|
||||||
|
// if there's more than one record, rotate them (round-robin)
|
||||||
|
if len(recordsCopy) > 1 {
|
||||||
|
d.mu.Lock()
|
||||||
|
records = d.records[question]
|
||||||
|
if len(records) > 1 {
|
||||||
|
first := records[0]
|
||||||
|
records = append(records[1:], first)
|
||||||
|
d.records[question] = records
|
||||||
|
}
|
||||||
|
d.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return recordsCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
maps.Clear(d.records)
|
||||||
|
|
||||||
|
for _, rec := range update {
|
||||||
|
if err := d.registerRecord(rec); err != nil {
|
||||||
|
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRecord stores a new record by appending it to any existing list
|
||||||
|
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
return d.registerRecord(record)
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerRecord performs the registration with the lock already held
|
||||||
|
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||||
|
rr, err := dns.NewRR(record.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("register record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr.Header().Rdlength = record.Len()
|
||||||
|
header := rr.Header()
|
||||||
|
q := dns.Question{
|
||||||
|
Name: strings.ToLower(dns.Fqdn(header.Name)),
|
||||||
|
Qtype: header.Rrtype,
|
||||||
|
Qclass: header.Class,
|
||||||
|
}
|
||||||
|
|
||||||
|
d.records[q] = append(d.records[q], rr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
472
client/internal/dns/local/local_test.go
Normal file
472
client/internal/dns/local/local_test.go
Normal file
@@ -0,0 +1,472 @@
|
|||||||
|
package local
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||||
|
recordA := nbdns.SimpleRecord{
|
||||||
|
Name: "peera.netbird.cloud.",
|
||||||
|
Type: 1,
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "1.2.3.4",
|
||||||
|
}
|
||||||
|
|
||||||
|
recordCNAME := nbdns.SimpleRecord{
|
||||||
|
Name: "peerb.netbird.cloud.",
|
||||||
|
Type: 5,
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "www.netbird.io",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputRecord nbdns.SimpleRecord
|
||||||
|
inputMSG *dns.Msg
|
||||||
|
responseShouldBeNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should Resolve A Record",
|
||||||
|
inputRecord: recordA,
|
||||||
|
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Resolve CNAME Record",
|
||||||
|
inputRecord: recordCNAME,
|
||||||
|
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Not Write When Not Found A Record",
|
||||||
|
inputRecord: recordA,
|
||||||
|
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||||
|
responseShouldBeNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
_ = resolver.RegisterRecord(testCase.inputRecord)
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||||
|
|
||||||
|
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||||
|
if testCase.responseShouldBeNil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("should write a response message")
|
||||||
|
}
|
||||||
|
|
||||||
|
answerString := responseMSG.Answer[0].String()
|
||||||
|
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||||
|
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||||
|
}
|
||||||
|
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||||
|
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||||
|
}
|
||||||
|
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||||
|
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_Update_StaleRecord verifies that updating
|
||||||
|
// a record correctly replaces the old one, preventing stale entries.
|
||||||
|
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||||
|
recordName := "host.example.com."
|
||||||
|
recordType := dns.TypeA
|
||||||
|
recordClass := dns.ClassINET
|
||||||
|
|
||||||
|
record1 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
|
||||||
|
}
|
||||||
|
record2 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
|
||||||
|
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
update1 := []nbdns.SimpleRecord{record1}
|
||||||
|
update2 := []nbdns.SimpleRecord{record2}
|
||||||
|
|
||||||
|
// Apply first update
|
||||||
|
resolver.Update(update1)
|
||||||
|
|
||||||
|
// Verify first update
|
||||||
|
resolver.mu.RLock()
|
||||||
|
rrSlice1, found1 := resolver.records[recordKey]
|
||||||
|
resolver.mu.RUnlock()
|
||||||
|
|
||||||
|
require.True(t, found1, "Record key %s not found after first update", recordKey)
|
||||||
|
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
|
||||||
|
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||||
|
|
||||||
|
// Apply second update
|
||||||
|
resolver.Update(update2)
|
||||||
|
|
||||||
|
// Verify second update
|
||||||
|
resolver.mu.RLock()
|
||||||
|
rrSlice2, found2 := resolver.records[recordKey]
|
||||||
|
resolver.mu.RUnlock()
|
||||||
|
|
||||||
|
require.True(t, found2, "Record key %s not found after second update", recordKey)
|
||||||
|
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
|
||||||
|
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
|
||||||
|
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
|
||||||
|
// with the same question are stored properly
|
||||||
|
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
recordName := "multi.example.com."
|
||||||
|
recordType := dns.TypeA
|
||||||
|
|
||||||
|
// Create two records with the same name and type but different IPs
|
||||||
|
record1 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
|
||||||
|
}
|
||||||
|
record2 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
update := []nbdns.SimpleRecord{record1, record2}
|
||||||
|
|
||||||
|
// Apply update with both records
|
||||||
|
resolver.Update(update)
|
||||||
|
|
||||||
|
// Create question that matches both records
|
||||||
|
question := dns.Question{
|
||||||
|
Name: recordName,
|
||||||
|
Qtype: recordType,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify both records are stored
|
||||||
|
resolver.mu.RLock()
|
||||||
|
records, found := resolver.records[question]
|
||||||
|
resolver.mu.RUnlock()
|
||||||
|
|
||||||
|
require.True(t, found, "Records for question %v not found", question)
|
||||||
|
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
|
||||||
|
|
||||||
|
// Verify both record data values are present
|
||||||
|
recordStrings := []string{records[0].String(), records[1].String()}
|
||||||
|
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
|
||||||
|
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
|
||||||
|
func TestLocalResolver_RecordRotation(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
recordName := "rotation.example.com."
|
||||||
|
recordType := dns.TypeA
|
||||||
|
|
||||||
|
// Create three records with the same name and type but different IPs
|
||||||
|
record1 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
|
||||||
|
}
|
||||||
|
record2 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
|
||||||
|
}
|
||||||
|
record3 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
|
||||||
|
}
|
||||||
|
|
||||||
|
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||||
|
|
||||||
|
// Apply update with all three records
|
||||||
|
resolver.Update(update)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||||
|
|
||||||
|
// First lookup - should return the records in original order
|
||||||
|
var responses [3]*dns.Msg
|
||||||
|
|
||||||
|
// Perform three lookups to verify rotation
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responses[i] = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all three responses contain answers
|
||||||
|
for i, resp := range responses {
|
||||||
|
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||||
|
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the first record in each response is different due to rotation
|
||||||
|
firstRecordIPs := []string{
|
||||||
|
responses[0].Answer[0].String(),
|
||||||
|
responses[1].Answer[0].String(),
|
||||||
|
responses[2].Answer[0].String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each record should be different (rotated)
|
||||||
|
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
|
||||||
|
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
|
||||||
|
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
|
||||||
|
|
||||||
|
// After three rotations, we should have cycled through all records
|
||||||
|
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
|
||||||
|
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
|
||||||
|
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
|
||||||
|
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
// Create record with lowercase name
|
||||||
|
lowerCaseRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "lower.example.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "10.10.10.10",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create record with mixed case name
|
||||||
|
mixedCaseRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "MiXeD.ExAmPlE.CoM.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "20.20.20.20",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update resolver with the records
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
queryName string
|
||||||
|
expectedRData string
|
||||||
|
shouldResolve bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Query lowercase with lowercase record",
|
||||||
|
queryName: "lower.example.com.",
|
||||||
|
expectedRData: "10.10.10.10",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query uppercase with lowercase record",
|
||||||
|
queryName: "LOWER.EXAMPLE.COM.",
|
||||||
|
expectedRData: "10.10.10.10",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query mixed case with lowercase record",
|
||||||
|
queryName: "LoWeR.eXaMpLe.CoM.",
|
||||||
|
expectedRData: "10.10.10.10",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query lowercase with mixed case record",
|
||||||
|
queryName: "mixed.example.com.",
|
||||||
|
expectedRData: "20.20.20.20",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query uppercase with mixed case record",
|
||||||
|
queryName: "MIXED.EXAMPLE.COM.",
|
||||||
|
expectedRData: "20.20.20.20",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query with different casing pattern",
|
||||||
|
queryName: "mIxEd.ExaMpLe.cOm.",
|
||||||
|
expectedRData: "20.20.20.20",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query non-existent domain",
|
||||||
|
queryName: "nonexistent.example.com.",
|
||||||
|
shouldResolve: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
|
||||||
|
// Create DNS query with the test case name
|
||||||
|
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
|
||||||
|
|
||||||
|
// Create mock response writer to capture the response
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform DNS query
|
||||||
|
resolver.ServeDNS(responseWriter, msg)
|
||||||
|
|
||||||
|
// Check if we expect a successful resolution
|
||||||
|
if !tc.shouldResolve {
|
||||||
|
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||||
|
// Expected no answer, test passes
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we got a response
|
||||||
|
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||||
|
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||||
|
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
answerString := responseMSG.Answer[0].String()
|
||||||
|
assert.Contains(t, answerString, tc.expectedRData,
|
||||||
|
"Answer should contain the expected IP address %s, got: %s",
|
||||||
|
tc.expectedRData, answerString)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
|
||||||
|
// to checking for CNAME records when the requested record type isn't found
|
||||||
|
func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
// Create a CNAME record (but no A record for this name)
|
||||||
|
cnameRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "alias.example.com.",
|
||||||
|
Type: int(dns.TypeCNAME),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "target.example.com.",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an A record for the CNAME target
|
||||||
|
targetRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "target.example.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.100.100",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update resolver with both records
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
queryName string
|
||||||
|
queryType uint16
|
||||||
|
expectedType string
|
||||||
|
expectedRData string
|
||||||
|
shouldResolve bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Directly query CNAME record",
|
||||||
|
queryName: "alias.example.com.",
|
||||||
|
queryType: dns.TypeCNAME,
|
||||||
|
expectedType: "CNAME",
|
||||||
|
expectedRData: "target.example.com.",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query A record but get CNAME fallback",
|
||||||
|
queryName: "alias.example.com.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
expectedType: "CNAME",
|
||||||
|
expectedRData: "target.example.com.",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query AAAA record but get CNAME fallback",
|
||||||
|
queryName: "alias.example.com.",
|
||||||
|
queryType: dns.TypeAAAA,
|
||||||
|
expectedType: "CNAME",
|
||||||
|
expectedRData: "target.example.com.",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query direct A record",
|
||||||
|
queryName: "target.example.com.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
expectedType: "A",
|
||||||
|
expectedRData: "192.168.100.100",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query non-existent name",
|
||||||
|
queryName: "nonexistent.example.com.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
shouldResolve: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
|
||||||
|
// Create DNS query with the test case parameters
|
||||||
|
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
|
||||||
|
|
||||||
|
// Create mock response writer to capture the response
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform DNS query
|
||||||
|
resolver.ServeDNS(responseWriter, msg)
|
||||||
|
|
||||||
|
// Check if we expect a successful resolution
|
||||||
|
if !tc.shouldResolve {
|
||||||
|
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
|
||||||
|
// Expected no resolution, test passes
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we got a successful response
|
||||||
|
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
|
||||||
|
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||||
|
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
answerString := responseMSG.Answer[0].String()
|
||||||
|
assert.Contains(t, answerString, tc.expectedType,
|
||||||
|
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
|
||||||
|
assert.Contains(t, answerString, tc.expectedRData,
|
||||||
|
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@@ -1,88 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
|
||||||
recordA := nbdns.SimpleRecord{
|
|
||||||
Name: "peera.netbird.cloud.",
|
|
||||||
Type: 1,
|
|
||||||
Class: nbdns.DefaultClass,
|
|
||||||
TTL: 300,
|
|
||||||
RData: "1.2.3.4",
|
|
||||||
}
|
|
||||||
|
|
||||||
recordCNAME := nbdns.SimpleRecord{
|
|
||||||
Name: "peerb.netbird.cloud.",
|
|
||||||
Type: 5,
|
|
||||||
Class: nbdns.DefaultClass,
|
|
||||||
TTL: 300,
|
|
||||||
RData: "www.netbird.io",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
inputRecord nbdns.SimpleRecord
|
|
||||||
inputMSG *dns.Msg
|
|
||||||
responseShouldBeNil bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Should Resolve A Record",
|
|
||||||
inputRecord: recordA,
|
|
||||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Resolve CNAME Record",
|
|
||||||
inputRecord: recordCNAME,
|
|
||||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Write When Not Found A Record",
|
|
||||||
inputRecord: recordA,
|
|
||||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
|
||||||
responseShouldBeNil: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
resolver := &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
}
|
|
||||||
_, _ = resolver.registerRecord(testCase.inputRecord)
|
|
||||||
var responseMSG *dns.Msg
|
|
||||||
responseWriter := &mockResponseWriter{
|
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
responseMSG = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
|
||||||
|
|
||||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
|
||||||
if testCase.responseShouldBeNil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("should write a response message")
|
|
||||||
}
|
|
||||||
|
|
||||||
answerString := responseMSG.Answer[0].String()
|
|
||||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
|
||||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
|
||||||
}
|
|
||||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
|
||||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
|
||||||
}
|
|
||||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
|
||||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,26 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockResponseWriter struct {
|
|
||||||
WriteMsgFunc func(m *dns.Msg) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
|
|
||||||
if rw.WriteMsgFunc != nil {
|
|
||||||
return rw.WriteMsgFunc(m)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
|
||||||
func (rw *mockResponseWriter) Close() error { return nil }
|
|
||||||
func (rw *mockResponseWriter) TsigStatus() error { return nil }
|
|
||||||
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
|
|
||||||
func (rw *mockResponseWriter) Hijack() {}
|
|
@@ -15,6 +15,8 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -46,8 +48,6 @@ type Server interface {
|
|||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerID string
|
|
||||||
|
|
||||||
type nsGroupsByDomain struct {
|
type nsGroupsByDomain struct {
|
||||||
domain string
|
domain string
|
||||||
groups []*nbdns.NameServerGroup
|
groups []*nbdns.NameServerGroup
|
||||||
@@ -61,7 +61,7 @@ type DefaultServer struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
localResolver *localResolver
|
localResolver *local.Resolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
@@ -84,9 +84,9 @@ type DefaultServer struct {
|
|||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
dns.Handler
|
dns.Handler
|
||||||
stop()
|
Stop()
|
||||||
probeAvailability()
|
ProbeAvailability()
|
||||||
id() handlerID
|
ID() types.HandlerID
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWrapper struct {
|
type handlerWrapper struct {
|
||||||
@@ -95,7 +95,7 @@ type handlerWrapper struct {
|
|||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[handlerID]handlerWrapper
|
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(
|
||||||
@@ -178,9 +178,7 @@ func newDefaultServer(
|
|||||||
handlerChain: handlerChain,
|
handlerChain: handlerChain,
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: local.NewResolver(),
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
stateManager: stateManager,
|
stateManager: stateManager,
|
||||||
@@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(mux handlerWithStop) {
|
go func(mux handlerWithStop) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
mux.probeAvailability()
|
mux.ProbeAvailability()
|
||||||
}(mux.handler)
|
}(mux.handler)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("local handler updater: %w", err)
|
return fmt.Errorf("local handler updater: %w", err)
|
||||||
}
|
}
|
||||||
@@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.updateMux(muxUpdates)
|
s.updateMux(muxUpdates)
|
||||||
|
|
||||||
// register local records
|
// register local records
|
||||||
s.updateLocalResolver(localRecordsByDomain)
|
s.localResolver.Update(localRecords)
|
||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
@@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
|
||||||
customZones []nbdns.CustomZone,
|
|
||||||
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
|
|
||||||
var muxUpdates []handlerWrapper
|
var muxUpdates []handlerWrapper
|
||||||
localRecords := make(map[string][]nbdns.SimpleRecord)
|
var localRecords []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
if len(customZone.Records) == 0 {
|
if len(customZone.Records) == 0 {
|
||||||
@@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
|||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
|
|
||||||
// group all records under this domain
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
var class uint16 = dns.ClassINET
|
|
||||||
if record.Class != nbdns.DefaultClass {
|
if record.Class != nbdns.DefaultClass {
|
||||||
log.Warnf("received an invalid class type: %s", record.Class)
|
log.Warnf("received an invalid class type: %s", record.Class)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// zone records contain the fqdn, so we can just flatten them
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
localRecords = append(localRecords, record)
|
||||||
|
|
||||||
localRecords[key] = append(localRecords[key], record)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
if len(handler.upstreamServers) == 0 {
|
||||||
handler.stop()
|
handler.Stop()
|
||||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
for _, existing := range s.dnsMuxMap {
|
for _, existing := range s.dnsMuxMap {
|
||||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||||
existing.handler.stop()
|
existing.handler.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
@@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
containsRootUpdate = true
|
containsRootUpdate = true
|
||||||
}
|
}
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.handler.id()] = update
|
muxUpdateMap[update.handler.ID()] = update
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's no root update and we had a root handler, restore it
|
// If there's no root update and we had a root handler, restore it
|
||||||
@@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
|
|
||||||
// remove old records that are no longer present
|
|
||||||
for key := range s.localResolver.registeredMap {
|
|
||||||
_, found := update[key]
|
|
||||||
if !found {
|
|
||||||
s.localResolver.deleteRecord(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap := make(registrationMap)
|
|
||||||
for _, recs := range update {
|
|
||||||
for _, rec := range recs {
|
|
||||||
// convert the record to a dns.RR and register
|
|
||||||
key, err := s.localResolver.registerRecord(rec)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("got an error while registering the record (%s), error: %v",
|
|
||||||
rec.String(), err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap[key] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localResolver.registeredMap = updatedMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
func getNSHostPort(ns nbdns.NameServer) string {
|
||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
}
|
}
|
||||||
|
@@ -23,6 +23,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
@@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
dummyHandler := &localResolver{}
|
dummyHandler := local.NewResolver()
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
initUpstreamMap registeredHandlerMap
|
initUpstreamMap registeredHandlerMap
|
||||||
initLocalMap registrationMap
|
initLocalRecords []nbdns.SimpleRecord
|
||||||
initSerial uint64
|
initSerial uint64
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
inputUpdate nbdns.Config
|
inputUpdate nbdns.Config
|
||||||
shouldFail bool
|
shouldFail bool
|
||||||
expectedUpstreamMap registeredHandlerMap
|
expectedUpstreamMap registeredHandlerMap
|
||||||
expectedLocalMap registrationMap
|
expectedLocalQs []dns.Question
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial Config Should Succeed",
|
name: "Initial Config Should Succeed",
|
||||||
initLocalMap: make(registrationMap),
|
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
@@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
dummyHandler.id(): handlerWrapper{
|
dummyHandler.ID(): handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
@@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@@ -216,11 +219,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 2,
|
initSerial: 2,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
@@ -228,7 +231,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
@@ -250,7 +253,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid NS Group Nameservers list Should Fail",
|
name: "Invalid NS Group Nameservers list Should Fail",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
@@ -272,7 +275,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid Custom Zone Records list Should Skip",
|
name: "Invalid Custom Zone Records list Should Skip",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
@@ -290,7 +293,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||||
domain: ".",
|
domain: ".",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
@@ -298,9 +301,9 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: make(registeredHandlerMap),
|
||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalQs: []dns.Question{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: make(registeredHandlerMap),
|
||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalQs: []dns.Question{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
@@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
|
var responseMSG *dns.Msg
|
||||||
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, q := range testCase.expectedLocalQs {
|
||||||
|
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||||
|
Question: []dns.Question{q},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for key := range testCase.expectedLocalMap {
|
if len(testCase.expectedLocalQs) > 0 {
|
||||||
_, found := dnsServer.localResolver.registeredMap[key]
|
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||||
if !found {
|
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||||
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
|
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||||
"id1": handlerWrapper{
|
"id1": handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: &localResolver{},
|
handler: &local.Resolver{},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||||
|
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
||||||
dnsServer.updateSerial = 0
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
@@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -632,9 +644,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
service: NewServiceViaMemory(&mocWGIface{}),
|
service: NewServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: local.NewResolver(),
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: hostManager,
|
hostManager: hostManager,
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
@@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tc.query, dns.TypeA)
|
r.SetQuestion(tc.query, dns.TypeA)
|
||||||
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
mh.On("ServeDNS", mock.Anything, r).Once()
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
@@ -1037,9 +1047,9 @@ type mockHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
func (m *mockHandler) stop() {}
|
func (m *mockHandler) Stop() {}
|
||||||
func (m *mockHandler) probeAvailability() {}
|
func (m *mockHandler) ProbeAvailability() {}
|
||||||
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||||
|
|
||||||
type mockService struct{}
|
type mockService struct{}
|
||||||
|
|
||||||
@@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
initialHandlers registeredHandlerMap
|
initialHandlers registeredHandlerMap
|
||||||
updates []handlerWrapper
|
updates []handlerWrapper
|
||||||
expectedHandlers map[string]string // map[handlerID]domain
|
expectedHandlers map[string]string // map[HandlerID]domain
|
||||||
description string
|
description string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
|
|
||||||
// Check each expected handler
|
// Check each expected handler
|
||||||
for id, expectedDomain := range tt.expectedHandlers {
|
for id, expectedDomain := range tt.expectedHandlers {
|
||||||
handler, exists := server.dnsMuxMap[handlerID(id)]
|
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||||
assert.True(t, exists, "Expected handler %s not found", id)
|
assert.True(t, exists, "Expected handler %s not found", id)
|
||||||
if exists {
|
if exists {
|
||||||
assert.Equal(t, expectedDomain, handler.domain,
|
assert.Equal(t, expectedDomain, handler.domain,
|
||||||
@@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify no unexpected handlers exist
|
// Verify no unexpected handlers exist
|
||||||
for handlerID := range server.dnsMuxMap {
|
for HandlerID := range server.dnsMuxMap {
|
||||||
_, expected := tt.expectedHandlers[string(handlerID)]
|
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||||
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the handlerChain state and order
|
// Verify the handlerChain state and order
|
||||||
@@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
wgInterface: &mocWGIface{},
|
wgInterface: &mocWGIface{},
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
@@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
@@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
@@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
26
client/internal/dns/test/mock.go
Normal file
26
client/internal/dns/test/mock.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockResponseWriter struct {
|
||||||
|
WriteMsgFunc func(m *dns.Msg) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||||
|
if rw.WriteMsgFunc != nil {
|
||||||
|
return rw.WriteMsgFunc(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
|
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
|
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
func (rw *MockResponseWriter) Close() error { return nil }
|
||||||
|
func (rw *MockResponseWriter) TsigStatus() error { return nil }
|
||||||
|
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
|
||||||
|
func (rw *MockResponseWriter) Hijack() {}
|
3
client/internal/dns/types/types.go
Normal file
3
client/internal/dns/types/types.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type HandlerID string
|
@@ -19,6 +19,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
@@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID
|
||||||
func (u *upstreamResolverBase) id() handlerID {
|
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||||
servers := slices.Clone(u.upstreamServers)
|
servers := slices.Clone(u.upstreamServers)
|
||||||
slices.Sort(servers)
|
slices.Sort(servers)
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
hash.Write([]byte(u.domain + ":"))
|
hash.Write([]byte(u.domain + ":"))
|
||||||
hash.Write([]byte(strings.Join(servers, ",")))
|
hash.Write([]byte(strings.Join(servers, ",")))
|
||||||
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) stop() {
|
func (u *upstreamResolverBase) Stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||||
u.cancel()
|
u.cancel()
|
||||||
}
|
}
|
||||||
@@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// probeAvailability tests all upstream servers simultaneously and
|
// ProbeAvailability tests all upstream servers simultaneously and
|
||||||
// disables the resolver if none work
|
// disables the resolver if none work
|
||||||
func (u *upstreamResolverBase) probeAvailability() {
|
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||||
u.mutex.Lock()
|
u.mutex.Lock()
|
||||||
defer u.mutex.Unlock()
|
defer u.mutex.Unlock()
|
||||||
|
|
||||||
|
@@ -8,6 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||||
@@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var responseMSG *dns.Msg
|
var responseMSG *dns.Msg
|
||||||
responseWriter := &mockResponseWriter{
|
responseWriter := &test.MockResponseWriter{
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
responseMSG = m
|
responseMSG = m
|
||||||
return nil
|
return nil
|
||||||
@@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
resolver.failsTillDeact = 0
|
resolver.failsTillDeact = 0
|
||||||
resolver.reactivatePeriod = time.Microsecond * 100
|
resolver.reactivatePeriod = time.Microsecond * 100
|
||||||
|
|
||||||
responseWriter := &mockResponseWriter{
|
responseWriter := &test.MockResponseWriter{
|
||||||
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
|
|||||||
func (s SimpleRecord) Len() uint16 {
|
func (s SimpleRecord) Len() uint16 {
|
||||||
emptyString := s.RData == ""
|
emptyString := s.RData == ""
|
||||||
switch s.Type {
|
switch s.Type {
|
||||||
case 1:
|
case int(dns.TypeA):
|
||||||
if emptyString {
|
if emptyString {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return net.IPv4len
|
return net.IPv4len
|
||||||
case 5:
|
case int(dns.TypeCNAME):
|
||||||
if emptyString || s.RData == "." {
|
if emptyString || s.RData == "." {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return uint16(len(s.RData) + 1)
|
return uint16(len(s.RData) + 1)
|
||||||
case 28:
|
case int(dns.TypeAAAA):
|
||||||
if emptyString {
|
if emptyString {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user