mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-06 06:06:35 +02:00
730 lines
21 KiB
Go
730 lines
21 KiB
Go
package dnsfwd
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/netip"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/management/domain"
|
|
"github.com/netbirdio/netbird/route"
|
|
)
|
|
|
|
func Test_getMatchingEntries(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
storedMappings map[string]route.ResID
|
|
queryDomain string
|
|
expectedResId route.ResID
|
|
}{
|
|
{
|
|
name: "Empty map returns empty string",
|
|
storedMappings: map[string]route.ResID{},
|
|
queryDomain: "example.com",
|
|
expectedResId: "",
|
|
},
|
|
{
|
|
name: "Exact match returns stored resId",
|
|
storedMappings: map[string]route.ResID{"example.com": "res1"},
|
|
queryDomain: "example.com",
|
|
expectedResId: "res1",
|
|
},
|
|
{
|
|
name: "Wildcard pattern matches base domain",
|
|
storedMappings: map[string]route.ResID{"*.example.com": "res2"},
|
|
queryDomain: "example.com",
|
|
expectedResId: "res2",
|
|
},
|
|
{
|
|
name: "Wildcard pattern matches subdomain",
|
|
storedMappings: map[string]route.ResID{"*.example.com": "res3"},
|
|
queryDomain: "foo.example.com",
|
|
expectedResId: "res3",
|
|
},
|
|
{
|
|
name: "Wildcard pattern does not match different domain",
|
|
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
|
queryDomain: "foo.example.org",
|
|
expectedResId: "",
|
|
},
|
|
{
|
|
name: "Non-wildcard pattern does not match subdomain",
|
|
storedMappings: map[string]route.ResID{"example.com": "res5"},
|
|
queryDomain: "foo.example.com",
|
|
expectedResId: "",
|
|
},
|
|
{
|
|
name: "Exact match over overlapping wildcard",
|
|
storedMappings: map[string]route.ResID{
|
|
"*.example.com": "resWildcard",
|
|
"foo.example.com": "resExact",
|
|
},
|
|
queryDomain: "foo.example.com",
|
|
expectedResId: "resExact",
|
|
},
|
|
{
|
|
name: "Overlapping wildcards: Select more specific wildcard",
|
|
storedMappings: map[string]route.ResID{
|
|
"*.example.com": "resA",
|
|
"*.sub.example.com": "resB",
|
|
},
|
|
queryDomain: "bar.sub.example.com",
|
|
expectedResId: "resB",
|
|
},
|
|
{
|
|
name: "Wildcard multi-level subdomain match",
|
|
storedMappings: map[string]route.ResID{
|
|
"*.example.com": "resMulti",
|
|
},
|
|
queryDomain: "a.b.example.com",
|
|
expectedResId: "resMulti",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
fwd := &DNSForwarder{}
|
|
|
|
var entries []*ForwarderEntry
|
|
for domainPattern, resId := range tc.storedMappings {
|
|
d, err := domain.FromString(domainPattern)
|
|
require.NoError(t, err)
|
|
entries = append(entries, &ForwarderEntry{
|
|
Domain: d,
|
|
ResID: resId,
|
|
})
|
|
}
|
|
fwd.UpdateDomains(entries)
|
|
|
|
got, _ := fwd.getMatchingEntries(tc.queryDomain)
|
|
assert.Equal(t, got, tc.expectedResId)
|
|
})
|
|
}
|
|
}
|
|
|
|
type MockFirewall struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|
args := m.Called(set, prefixes)
|
|
return args.Error(0)
|
|
}
|
|
|
|
type MockResolver struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
|
args := m.Called(ctx, network, host)
|
|
return args.Get(0).([]netip.Addr), args.Error(1)
|
|
}
|
|
|
|
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
configuredDomain string
|
|
queryDomain string
|
|
shouldMatch bool
|
|
expectedResID route.ResID
|
|
description string
|
|
}{
|
|
{
|
|
name: "exact domain match should be allowed",
|
|
configuredDomain: "example.com",
|
|
queryDomain: "example.com",
|
|
shouldMatch: true,
|
|
expectedResID: "test-res-id",
|
|
description: "Direct match to configured domain should work",
|
|
},
|
|
{
|
|
name: "subdomain access should be restricted",
|
|
configuredDomain: "example.com",
|
|
queryDomain: "mail.example.com",
|
|
shouldMatch: false,
|
|
expectedResID: "",
|
|
description: "Subdomain should not be accessible unless explicitly configured",
|
|
},
|
|
{
|
|
name: "wildcard should allow subdomains",
|
|
configuredDomain: "*.example.com",
|
|
queryDomain: "mail.example.com",
|
|
shouldMatch: true,
|
|
expectedResID: "test-res-id",
|
|
description: "Wildcard domains should allow subdomain access",
|
|
},
|
|
{
|
|
name: "wildcard should allow base domain",
|
|
configuredDomain: "*.example.com",
|
|
queryDomain: "example.com",
|
|
shouldMatch: true,
|
|
expectedResID: "test-res-id",
|
|
description: "Wildcard should also match the base domain",
|
|
},
|
|
{
|
|
name: "deep subdomain should be restricted",
|
|
configuredDomain: "example.com",
|
|
queryDomain: "deep.mail.example.com",
|
|
shouldMatch: false,
|
|
expectedResID: "",
|
|
description: "Deep subdomains should not be accessible",
|
|
},
|
|
{
|
|
name: "wildcard allows deep subdomains",
|
|
configuredDomain: "*.example.com",
|
|
queryDomain: "deep.mail.example.com",
|
|
shouldMatch: true,
|
|
expectedResID: "test-res-id",
|
|
description: "Wildcard should allow deep subdomains",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
forwarder := &DNSForwarder{}
|
|
|
|
d, err := domain.FromString(tt.configuredDomain)
|
|
require.NoError(t, err)
|
|
|
|
entries := []*ForwarderEntry{
|
|
{
|
|
Domain: d,
|
|
ResID: "test-res-id",
|
|
},
|
|
}
|
|
|
|
forwarder.UpdateDomains(entries)
|
|
|
|
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
|
|
|
|
if tt.shouldMatch {
|
|
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
|
|
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
|
|
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
|
|
} else {
|
|
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
|
|
assert.Empty(t, matchingEntries, "Expected no matching entries")
|
|
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping integration test in short mode")
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
configuredDomain string
|
|
queryDomain string
|
|
shouldResolve bool
|
|
description string
|
|
}{
|
|
{
|
|
name: "configured exact domain resolves",
|
|
configuredDomain: "example.com",
|
|
queryDomain: "example.com",
|
|
shouldResolve: true,
|
|
description: "Exact match should resolve",
|
|
},
|
|
{
|
|
name: "unauthorized subdomain blocked",
|
|
configuredDomain: "example.com",
|
|
queryDomain: "mail.example.com",
|
|
shouldResolve: false,
|
|
description: "Subdomain should be blocked without wildcard",
|
|
},
|
|
{
|
|
name: "wildcard allows subdomain",
|
|
configuredDomain: "*.example.com",
|
|
queryDomain: "mail.example.com",
|
|
shouldResolve: true,
|
|
description: "Wildcard should allow subdomain",
|
|
},
|
|
{
|
|
name: "wildcard allows base domain",
|
|
configuredDomain: "*.example.com",
|
|
queryDomain: "example.com",
|
|
shouldResolve: true,
|
|
description: "Wildcard should allow base domain",
|
|
},
|
|
{
|
|
name: "unrelated domain blocked",
|
|
configuredDomain: "example.com",
|
|
queryDomain: "example.org",
|
|
shouldResolve: false,
|
|
description: "Unrelated domain should be blocked",
|
|
},
|
|
{
|
|
name: "deep subdomain blocked",
|
|
configuredDomain: "example.com",
|
|
queryDomain: "deep.mail.example.com",
|
|
shouldResolve: false,
|
|
description: "Deep subdomain should be blocked",
|
|
},
|
|
{
|
|
name: "wildcard allows deep subdomain",
|
|
configuredDomain: "*.example.com",
|
|
queryDomain: "deep.mail.example.com",
|
|
shouldResolve: true,
|
|
description: "Wildcard should allow deep subdomain",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockFirewall := &MockFirewall{}
|
|
mockResolver := &MockResolver{}
|
|
|
|
if tt.shouldResolve {
|
|
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
|
|
|
|
// Mock successful DNS resolution
|
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
|
}
|
|
|
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
|
forwarder.resolver = mockResolver
|
|
|
|
d, err := domain.FromString(tt.configuredDomain)
|
|
require.NoError(t, err)
|
|
|
|
entries := []*ForwarderEntry{
|
|
{
|
|
Domain: d,
|
|
ResID: "test-res-id",
|
|
Set: firewall.NewDomainSet([]domain.Domain{d}),
|
|
},
|
|
}
|
|
|
|
forwarder.UpdateDomains(entries)
|
|
|
|
query := &dns.Msg{}
|
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
|
|
|
mockWriter := &test.MockResponseWriter{}
|
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
|
|
|
if tt.shouldResolve {
|
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
|
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
mockFirewall.AssertExpectations(t)
|
|
mockResolver.AssertExpectations(t)
|
|
} else {
|
|
if resp != nil {
|
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
|
"Unauthorized domain should not return successful answers")
|
|
}
|
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
configuredDomains []string
|
|
query string
|
|
mockIP string
|
|
shouldResolve bool
|
|
expectedSetCount int // How many sets should be updated
|
|
description string
|
|
}{
|
|
{
|
|
name: "exact domain gets firewall update",
|
|
configuredDomains: []string{"example.com"},
|
|
query: "example.com",
|
|
mockIP: "1.1.1.1",
|
|
shouldResolve: true,
|
|
expectedSetCount: 1,
|
|
description: "Single exact match updates one set",
|
|
},
|
|
{
|
|
name: "wildcard domain gets firewall update",
|
|
configuredDomains: []string{"*.example.com"},
|
|
query: "mail.example.com",
|
|
mockIP: "1.1.1.2",
|
|
shouldResolve: true,
|
|
expectedSetCount: 1,
|
|
description: "Wildcard match updates one set",
|
|
},
|
|
{
|
|
name: "overlapping exact and wildcard both get updates",
|
|
configuredDomains: []string{"*.example.com", "mail.example.com"},
|
|
query: "mail.example.com",
|
|
mockIP: "1.1.1.3",
|
|
shouldResolve: true,
|
|
expectedSetCount: 2,
|
|
description: "Both exact and wildcard sets should be updated",
|
|
},
|
|
{
|
|
name: "unauthorized domain gets no firewall update",
|
|
configuredDomains: []string{"example.com"},
|
|
query: "mail.example.com",
|
|
mockIP: "1.1.1.4",
|
|
shouldResolve: false,
|
|
expectedSetCount: 0,
|
|
description: "No firewall update for unauthorized domains",
|
|
},
|
|
{
|
|
name: "multiple wildcards matching get all updated",
|
|
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
|
|
query: "test.sub.example.com",
|
|
mockIP: "1.1.1.5",
|
|
shouldResolve: true,
|
|
expectedSetCount: 2,
|
|
description: "All matching wildcard sets should be updated",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockFirewall := &MockFirewall{}
|
|
mockResolver := &MockResolver{}
|
|
|
|
// Set up forwarder
|
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
|
forwarder.resolver = mockResolver
|
|
|
|
// Create entries and track sets
|
|
var entries []*ForwarderEntry
|
|
sets := make([]firewall.Set, 0)
|
|
|
|
for i, configDomain := range tt.configuredDomains {
|
|
d, err := domain.FromString(configDomain)
|
|
require.NoError(t, err)
|
|
|
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
|
sets = append(sets, set)
|
|
|
|
entries = append(entries, &ForwarderEntry{
|
|
Domain: d,
|
|
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
|
|
Set: set,
|
|
})
|
|
}
|
|
|
|
forwarder.UpdateDomains(entries)
|
|
|
|
// Set up mocks
|
|
if tt.shouldResolve {
|
|
fakeIP := netip.MustParseAddr(tt.mockIP)
|
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
|
|
Return([]netip.Addr{fakeIP}, nil).Once()
|
|
|
|
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
|
|
|
|
// Count how many sets should actually match
|
|
updateCount := 0
|
|
for i, entry := range entries {
|
|
domain := strings.ToLower(tt.query)
|
|
pattern := entry.Domain.PunycodeString()
|
|
|
|
matches := false
|
|
if strings.HasPrefix(pattern, "*.") {
|
|
baseDomain := strings.TrimPrefix(pattern, "*.")
|
|
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
|
matches = true
|
|
}
|
|
} else if domain == pattern {
|
|
matches = true
|
|
}
|
|
|
|
if matches {
|
|
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
|
|
updateCount++
|
|
}
|
|
}
|
|
|
|
assert.Equal(t, tt.expectedSetCount, updateCount,
|
|
"Expected %d sets to be updated, but mock expects %d",
|
|
tt.expectedSetCount, updateCount)
|
|
}
|
|
|
|
// Execute query
|
|
dnsQuery := &dns.Msg{}
|
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
|
|
|
mockWriter := &test.MockResponseWriter{}
|
|
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
|
|
|
// Verify response
|
|
if tt.shouldResolve {
|
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
|
require.NotEmpty(t, resp.Answer)
|
|
} else if resp != nil {
|
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
|
"Unauthorized domain should be refused or have no answers")
|
|
}
|
|
|
|
// Verify all mock expectations were met
|
|
mockFirewall.AssertExpectations(t)
|
|
mockResolver.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
|
|
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|
mockFirewall := &MockFirewall{}
|
|
mockResolver := &MockResolver{}
|
|
|
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
|
forwarder.resolver = mockResolver
|
|
|
|
// Configure a single domain
|
|
d, err := domain.FromString("example.com")
|
|
require.NoError(t, err)
|
|
|
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
|
entries := []*ForwarderEntry{{
|
|
Domain: d,
|
|
ResID: "test-res",
|
|
Set: set,
|
|
}}
|
|
|
|
forwarder.UpdateDomains(entries)
|
|
|
|
// Mock resolver returns multiple IPs
|
|
ips := []netip.Addr{
|
|
netip.MustParseAddr("1.1.1.1"),
|
|
netip.MustParseAddr("1.1.1.2"),
|
|
netip.MustParseAddr("1.1.1.3"),
|
|
}
|
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
|
Return(ips, nil).Once()
|
|
|
|
// Expect ONE UpdateSet call with ALL prefixes
|
|
expectedPrefixes := []netip.Prefix{
|
|
netip.PrefixFrom(ips[0], 32),
|
|
netip.PrefixFrom(ips[1], 32),
|
|
netip.PrefixFrom(ips[2], 32),
|
|
}
|
|
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
|
|
|
|
// Execute query
|
|
query := &dns.Msg{}
|
|
query.SetQuestion("example.com.", dns.TypeA)
|
|
|
|
mockWriter := &test.MockResponseWriter{}
|
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
|
|
|
// Verify response contains all IPs
|
|
require.NotNil(t, resp)
|
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
|
|
|
// Verify mocks
|
|
mockFirewall.AssertExpectations(t)
|
|
mockResolver.AssertExpectations(t)
|
|
}
|
|
|
|
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
queryType uint16
|
|
queryDomain string
|
|
configured string
|
|
expectedCode int
|
|
description string
|
|
}{
|
|
{
|
|
name: "unauthorized domain returns REFUSED",
|
|
queryType: dns.TypeA,
|
|
queryDomain: "evil.com",
|
|
configured: "example.com",
|
|
expectedCode: dns.RcodeRefused,
|
|
description: "RFC compliant REFUSED for unauthorized queries",
|
|
},
|
|
{
|
|
name: "unsupported query type returns NOTIMP",
|
|
queryType: dns.TypeMX,
|
|
queryDomain: "example.com",
|
|
configured: "example.com",
|
|
expectedCode: dns.RcodeNotImplemented,
|
|
description: "RFC compliant NOTIMP for unsupported types",
|
|
},
|
|
{
|
|
name: "CNAME query returns NOTIMP",
|
|
queryType: dns.TypeCNAME,
|
|
queryDomain: "example.com",
|
|
configured: "example.com",
|
|
expectedCode: dns.RcodeNotImplemented,
|
|
description: "CNAME queries not supported",
|
|
},
|
|
{
|
|
name: "TXT query returns NOTIMP",
|
|
queryType: dns.TypeTXT,
|
|
queryDomain: "example.com",
|
|
configured: "example.com",
|
|
expectedCode: dns.RcodeNotImplemented,
|
|
description: "TXT queries not supported",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
|
|
|
d, err := domain.FromString(tt.configured)
|
|
require.NoError(t, err)
|
|
|
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
|
forwarder.UpdateDomains(entries)
|
|
|
|
query := &dns.Msg{}
|
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
|
|
|
// Capture the written response
|
|
var writtenResp *dns.Msg
|
|
mockWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
writtenResp = m
|
|
return nil
|
|
},
|
|
}
|
|
|
|
_ = forwarder.handleDNSQuery(mockWriter, query)
|
|
|
|
// Check the response written to the writer
|
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
|
// Test that large UDP responses are truncated with TC bit set
|
|
mockResolver := &MockResolver{}
|
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
|
forwarder.resolver = mockResolver
|
|
|
|
d, _ := domain.FromString("example.com")
|
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
|
forwarder.UpdateDomains(entries)
|
|
|
|
// Mock many IPs to create a large response
|
|
var manyIPs []netip.Addr
|
|
for i := 0; i < 100; i++ {
|
|
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
|
|
}
|
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
|
|
|
|
// Query without EDNS0
|
|
query := &dns.Msg{}
|
|
query.SetQuestion("example.com.", dns.TypeA)
|
|
|
|
var writtenResp *dns.Msg
|
|
mockWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
writtenResp = m
|
|
return nil
|
|
},
|
|
}
|
|
forwarder.handleDNSQueryUDP(mockWriter, query)
|
|
|
|
require.NotNil(t, writtenResp)
|
|
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
|
|
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
|
}
|
|
|
|
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|
// Test complex overlapping pattern scenarios
|
|
mockFirewall := &MockFirewall{}
|
|
mockResolver := &MockResolver{}
|
|
|
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
|
forwarder.resolver = mockResolver
|
|
|
|
// Set up complex overlapping patterns
|
|
patterns := []string{
|
|
"*.example.com", // Matches all subdomains
|
|
"*.mail.example.com", // More specific wildcard
|
|
"smtp.mail.example.com", // Exact match
|
|
"example.com", // Base domain
|
|
}
|
|
|
|
var entries []*ForwarderEntry
|
|
sets := make(map[string]firewall.Set)
|
|
|
|
for _, pattern := range patterns {
|
|
d, _ := domain.FromString(pattern)
|
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
|
sets[pattern] = set
|
|
entries = append(entries, &ForwarderEntry{
|
|
Domain: d,
|
|
ResID: route.ResID("res-" + pattern),
|
|
Set: set,
|
|
})
|
|
}
|
|
|
|
forwarder.UpdateDomains(entries)
|
|
|
|
// Test smtp.mail.example.com - should match 3 patterns
|
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
|
|
|
|
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
|
|
// All three matching patterns should get firewall updates
|
|
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
|
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
|
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
|
|
|
query := &dns.Msg{}
|
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
|
|
|
mockWriter := &test.MockResponseWriter{}
|
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
|
|
|
require.NotNil(t, resp)
|
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
|
|
|
// Verify all three sets were updated
|
|
mockFirewall.AssertExpectations(t)
|
|
|
|
// Verify the most specific ResID was selected
|
|
// (exact match should win over wildcards)
|
|
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
|
|
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
|
|
assert.Len(t, matches, 3, "Should match 3 patterns")
|
|
}
|
|
|
|
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|
// Test handling of malformed query with no questions
|
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
|
|
|
query := &dns.Msg{}
|
|
// Don't set any question
|
|
|
|
writeCalled := false
|
|
mockWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
writeCalled = true
|
|
return nil
|
|
},
|
|
}
|
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
|
|
|
assert.Nil(t, resp, "Should return nil for empty query")
|
|
assert.False(t, writeCalled, "Should not write response for empty query")
|
|
}
|