Mirror of https://github.com/netbirdio/netbird.git (synced 2025-01-25 07:19:05 +01:00)
Commit ddc365f7a0
Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
512 lines, 15 KiB, Go
package dns_test
|
|
|
|
import (
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
|
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
|
)
|
|
|
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
|
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|
chain := nbdns.NewHandlerChain()
|
|
|
|
// Create mock handlers for different priorities
|
|
defaultHandler := &nbdns.MockHandler{}
|
|
matchDomainHandler := &nbdns.MockHandler{}
|
|
dnsRouteHandler := &nbdns.MockHandler{}
|
|
|
|
// Setup handlers with different priorities
|
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
|
|
|
// Create test request
|
|
r := new(dns.Msg)
|
|
r.SetQuestion("example.com.", dns.TypeA)
|
|
|
|
// Create test writer
|
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
|
|
// Setup expectations - only highest priority handler should be called
|
|
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
|
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
|
|
|
// Execute
|
|
chain.ServeDNS(w, r)
|
|
|
|
// Verify all expectations were met
|
|
dnsRouteHandler.AssertExpectations(t)
|
|
matchDomainHandler.AssertExpectations(t)
|
|
defaultHandler.AssertExpectations(t)
|
|
}
|
|
|
|
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
|
|
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
handlerDomain string
|
|
queryDomain string
|
|
isWildcard bool
|
|
matchSubdomains bool
|
|
shouldMatch bool
|
|
}{
|
|
{
|
|
name: "exact match",
|
|
handlerDomain: "example.com.",
|
|
queryDomain: "example.com.",
|
|
isWildcard: false,
|
|
matchSubdomains: false,
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "subdomain with non-wildcard and MatchSubdomains true",
|
|
handlerDomain: "example.com.",
|
|
queryDomain: "sub.example.com.",
|
|
isWildcard: false,
|
|
matchSubdomains: true,
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "subdomain with non-wildcard and MatchSubdomains false",
|
|
handlerDomain: "example.com.",
|
|
queryDomain: "sub.example.com.",
|
|
isWildcard: false,
|
|
matchSubdomains: false,
|
|
shouldMatch: false,
|
|
},
|
|
{
|
|
name: "wildcard match",
|
|
handlerDomain: "*.example.com.",
|
|
queryDomain: "sub.example.com.",
|
|
isWildcard: true,
|
|
matchSubdomains: false,
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "wildcard no match on apex",
|
|
handlerDomain: "*.example.com.",
|
|
queryDomain: "example.com.",
|
|
isWildcard: true,
|
|
matchSubdomains: false,
|
|
shouldMatch: false,
|
|
},
|
|
{
|
|
name: "root zone match",
|
|
handlerDomain: ".",
|
|
queryDomain: "anything.com.",
|
|
isWildcard: false,
|
|
matchSubdomains: false,
|
|
shouldMatch: true,
|
|
},
|
|
{
|
|
name: "no match different domain",
|
|
handlerDomain: "example.com.",
|
|
queryDomain: "example.org.",
|
|
isWildcard: false,
|
|
matchSubdomains: false,
|
|
shouldMatch: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
chain := nbdns.NewHandlerChain()
|
|
var handler dns.Handler
|
|
|
|
if tt.matchSubdomains {
|
|
mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
|
handler = mockSubHandler
|
|
if tt.shouldMatch {
|
|
mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
|
}
|
|
} else {
|
|
mockHandler := &nbdns.MockHandler{}
|
|
handler = mockHandler
|
|
if tt.shouldMatch {
|
|
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
|
}
|
|
}
|
|
|
|
pattern := tt.handlerDomain
|
|
if tt.isWildcard {
|
|
pattern = "*." + tt.handlerDomain[2:]
|
|
}
|
|
|
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
|
|
|
r := new(dns.Msg)
|
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
|
|
chain.ServeDNS(w, r)
|
|
|
|
if h, ok := handler.(*nbdns.MockHandler); ok {
|
|
h.AssertExpectations(t)
|
|
} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
|
|
h.AssertExpectations(t)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
|
|
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
handlers []struct {
|
|
pattern string
|
|
priority int
|
|
}
|
|
queryDomain string
|
|
expectedCalls int
|
|
expectedHandler int // index of the handler that should be called
|
|
}{
|
|
{
|
|
name: "wildcard and exact same priority - exact should win",
|
|
handlers: []struct {
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
|
},
|
|
queryDomain: "example.com.",
|
|
expectedCalls: 1,
|
|
expectedHandler: 1, // exact match handler should be called
|
|
},
|
|
{
|
|
name: "higher priority wildcard over lower priority exact",
|
|
handlers: []struct {
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
|
},
|
|
queryDomain: "test.example.com.",
|
|
expectedCalls: 1,
|
|
expectedHandler: 1, // higher priority wildcard handler should be called
|
|
},
|
|
{
|
|
name: "multiple wildcards different priorities",
|
|
handlers: []struct {
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
|
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
|
},
|
|
queryDomain: "test.example.com.",
|
|
expectedCalls: 1,
|
|
expectedHandler: 2, // highest priority handler should be called
|
|
},
|
|
{
|
|
name: "subdomain with mix of patterns",
|
|
handlers: []struct {
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
|
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
|
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
|
},
|
|
queryDomain: "sub.test.example.com.",
|
|
expectedCalls: 1,
|
|
expectedHandler: 2, // highest priority matching handler should be called
|
|
},
|
|
{
|
|
name: "root zone with specific domain",
|
|
handlers: []struct {
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{pattern: ".", priority: nbdns.PriorityDefault},
|
|
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
|
|
},
|
|
queryDomain: "example.com.",
|
|
expectedCalls: 1,
|
|
expectedHandler: 1, // higher priority specific domain should win over root
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
chain := nbdns.NewHandlerChain()
|
|
var handlers []*nbdns.MockHandler
|
|
|
|
// Setup handlers and expectations
|
|
for i := range tt.handlers {
|
|
handler := &nbdns.MockHandler{}
|
|
handlers = append(handlers, handler)
|
|
|
|
// Set expectation based on whether this handler should be called
|
|
if i == tt.expectedHandler {
|
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
|
} else {
|
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
|
}
|
|
|
|
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
|
}
|
|
|
|
// Create and execute request
|
|
r := new(dns.Msg)
|
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
chain.ServeDNS(w, r)
|
|
|
|
// Verify expectations
|
|
for _, handler := range handlers {
|
|
handler.AssertExpectations(t)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
|
|
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|
chain := nbdns.NewHandlerChain()
|
|
|
|
// Create handlers
|
|
handler1 := &nbdns.MockHandler{}
|
|
handler2 := &nbdns.MockHandler{}
|
|
handler3 := &nbdns.MockHandler{}
|
|
|
|
// Add handlers in priority order
|
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
|
|
|
// Create test request
|
|
r := new(dns.Msg)
|
|
r.SetQuestion("example.com.", dns.TypeA)
|
|
|
|
// Setup mock responses to simulate chain continuation
|
|
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
|
// First handler signals continue
|
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
|
resp := new(dns.Msg)
|
|
resp.SetRcode(r, dns.RcodeNameError)
|
|
resp.MsgHdr.Zero = true // Signal to continue
|
|
assert.NoError(t, w.WriteMsg(resp))
|
|
}).Once()
|
|
|
|
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
|
// Second handler signals continue
|
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
|
resp := new(dns.Msg)
|
|
resp.SetRcode(r, dns.RcodeNameError)
|
|
resp.MsgHdr.Zero = true
|
|
assert.NoError(t, w.WriteMsg(resp))
|
|
}).Once()
|
|
|
|
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
|
// Last handler responds normally
|
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
|
resp := new(dns.Msg)
|
|
resp.SetRcode(r, dns.RcodeSuccess)
|
|
assert.NoError(t, w.WriteMsg(resp))
|
|
}).Once()
|
|
|
|
// Execute
|
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
chain.ServeDNS(w, r)
|
|
|
|
// Verify all handlers were called in order
|
|
handler1.AssertExpectations(t)
|
|
handler2.AssertExpectations(t)
|
|
handler3.AssertExpectations(t)
|
|
}
|
|
|
|
// mockResponseWriter implements dns.ResponseWriter for testing
|
|
type mockResponseWriter struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
|
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
|
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
|
func (m *mockResponseWriter) Close() error { return nil }
|
|
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
|
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
|
func (m *mockResponseWriter) Hijack() {}
|
|
|
|
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
ops []struct {
|
|
action string // "add" or "remove"
|
|
pattern string
|
|
priority int
|
|
}
|
|
query string
|
|
expectedCalls map[int]bool // map[priority]shouldBeCalled
|
|
}{
|
|
{
|
|
name: "remove high priority keeps lower priority handler",
|
|
ops: []struct {
|
|
action string
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
|
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
|
},
|
|
query: "example.com.",
|
|
expectedCalls: map[int]bool{
|
|
nbdns.PriorityDNSRoute: false,
|
|
nbdns.PriorityMatchDomain: true,
|
|
},
|
|
},
|
|
{
|
|
name: "remove lower priority keeps high priority handler",
|
|
ops: []struct {
|
|
action string
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
|
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
|
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
|
},
|
|
query: "example.com.",
|
|
expectedCalls: map[int]bool{
|
|
nbdns.PriorityDNSRoute: true,
|
|
nbdns.PriorityMatchDomain: false,
|
|
},
|
|
},
|
|
{
|
|
name: "remove all handlers in order",
|
|
ops: []struct {
|
|
action string
|
|
pattern string
|
|
priority int
|
|
}{
|
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
|
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
|
{"add", "example.com.", nbdns.PriorityDefault},
|
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
|
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
|
},
|
|
query: "example.com.",
|
|
expectedCalls: map[int]bool{
|
|
nbdns.PriorityDNSRoute: false,
|
|
nbdns.PriorityMatchDomain: false,
|
|
nbdns.PriorityDefault: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
chain := nbdns.NewHandlerChain()
|
|
handlers := make(map[int]*nbdns.MockHandler)
|
|
|
|
// Execute operations
|
|
for _, op := range tt.ops {
|
|
if op.action == "add" {
|
|
handler := &nbdns.MockHandler{}
|
|
handlers[op.priority] = handler
|
|
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
|
} else {
|
|
chain.RemoveHandler(op.pattern, op.priority)
|
|
}
|
|
}
|
|
|
|
// Create test request
|
|
r := new(dns.Msg)
|
|
r.SetQuestion(tt.query, dns.TypeA)
|
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
|
|
// Setup expectations
|
|
for priority, handler := range handlers {
|
|
if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
|
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
|
} else {
|
|
handler.On("ServeDNS", mock.Anything, r).Maybe()
|
|
}
|
|
}
|
|
|
|
// Execute request
|
|
chain.ServeDNS(w, r)
|
|
|
|
// Verify expectations
|
|
for _, handler := range handlers {
|
|
handler.AssertExpectations(t)
|
|
}
|
|
|
|
// Verify handler exists check
|
|
for priority, shouldExist := range tt.expectedCalls {
|
|
if shouldExist {
|
|
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
|
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|
chain := nbdns.NewHandlerChain()
|
|
|
|
testDomain := "example.com."
|
|
testQuery := "test.example.com."
|
|
|
|
// Create handlers with MatchSubdomains enabled
|
|
routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
|
matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
|
defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
|
|
|
// Create test request that will be reused
|
|
r := new(dns.Msg)
|
|
r.SetQuestion(testQuery, dns.TypeA)
|
|
|
|
// Add handlers in mixed order
|
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
|
|
|
// Test 1: Initial state with all three handlers
|
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
// Highest priority handler (routeHandler) should be called
|
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
|
|
|
chain.ServeDNS(w, r)
|
|
routeHandler.AssertExpectations(t)
|
|
|
|
// Test 2: Remove highest priority handler
|
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
|
assert.True(t, chain.HasHandlers(testDomain))
|
|
|
|
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
// Now middle priority handler (matchHandler) should be called
|
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
|
|
|
chain.ServeDNS(w, r)
|
|
matchHandler.AssertExpectations(t)
|
|
|
|
// Test 3: Remove middle priority handler
|
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
|
assert.True(t, chain.HasHandlers(testDomain))
|
|
|
|
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
// Now lowest priority handler (defaultHandler) should be called
|
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
|
|
|
chain.ServeDNS(w, r)
|
|
defaultHandler.AssertExpectations(t)
|
|
|
|
// Test 4: Remove last handler
|
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
|
assert.False(t, chain.HasHandlers(testDomain))
|
|
}
|