mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 02:21:47 +02:00
[client] Automatically register match domains for DNS routes (#3614)
This commit is contained in:
@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
||||
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
return listOfDomains
|
||||
}
|
||||
|
@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
}
|
||||
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
c.removeEntry(origPattern, priority)
|
||||
|
||||
// Check if handler implements SubdomainMatcher interface
|
||||
matchSubdomains := false
|
||||
@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
c.removeEntry(pattern, priority)
|
||||
}
|
||||
|
||||
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
for _, entry := range c.handlers {
|
||||
if strings.EqualFold(entry.Pattern, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
|
@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
for _, handler := range handlers {
|
||||
handler.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verify handler exists check
|
||||
for priority, shouldExist := range tt.expectedCalls {
|
||||
if shouldExist {
|
||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(testQuery, dns.TypeA)
|
||||
|
||||
// Keep track of mocks for the final assertion in Step 4
|
||||
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
|
||||
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state with all three handlers
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Test 1: Initial state
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Highest priority handler (routeHandler) should be called
|
||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w1, r)
|
||||
routeHandler.AssertExpectations(t)
|
||||
|
||||
routeHandler.ExpectedCalls = nil
|
||||
routeHandler.Calls = nil
|
||||
matchHandler.ExpectedCalls = nil
|
||||
matchHandler.Calls = nil
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 2: Remove highest priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now middle priority handler (matchHandler) should be called
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w2, r)
|
||||
matchHandler.AssertExpectations(t)
|
||||
|
||||
matchHandler.ExpectedCalls = nil
|
||||
matchHandler.Calls = nil
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 3: Remove middle priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now lowest priority handler (defaultHandler) should be called
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w3, r)
|
||||
defaultHandler.AssertExpectations(t)
|
||||
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
assert.False(t, chain.HasHandlers(testDomain))
|
||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||
|
||||
for _, m := range mocks {
|
||||
m.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addPattern string
|
||||
removePattern string
|
||||
queryPattern string
|
||||
shouldBeRemoved bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "exact same pattern",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "example.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing with identical patterns",
|
||||
},
|
||||
{
|
||||
name: "case difference",
|
||||
addPattern: "Example.Com.",
|
||||
removePattern: "EXAMPLE.COM.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with mixed case, removing with uppercase",
|
||||
},
|
||||
{
|
||||
name: "reversed case difference",
|
||||
addPattern: "EXAMPLE.ORG.",
|
||||
removePattern: "example.org.",
|
||||
queryPattern: "example.org.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with uppercase, removing with lowercase",
|
||||
},
|
||||
{
|
||||
name: "add wildcard, remove wildcard",
|
||||
addPattern: "*.example.com.",
|
||||
removePattern: "*.example.com.",
|
||||
queryPattern: "sub.example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing with identical wildcard patterns",
|
||||
},
|
||||
{
|
||||
name: "add wildcard, remove transformed pattern",
|
||||
addPattern: "*.example.net.",
|
||||
removePattern: "example.net.",
|
||||
queryPattern: "sub.example.net.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding with wildcard, removing with non-wildcard pattern",
|
||||
},
|
||||
{
|
||||
name: "add transformed pattern, remove wildcard",
|
||||
addPattern: "example.io.",
|
||||
removePattern: "*.example.io.",
|
||||
queryPattern: "example.io.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
|
||||
},
|
||||
{
|
||||
name: "trailing dot difference",
|
||||
addPattern: "example.dev",
|
||||
removePattern: "example.dev.",
|
||||
queryPattern: "example.dev.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding without trailing dot, removing with trailing dot",
|
||||
},
|
||||
{
|
||||
name: "reversed trailing dot difference",
|
||||
addPattern: "example.app.",
|
||||
removePattern: "example.app",
|
||||
queryPattern: "example.app.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with trailing dot, removing without trailing dot",
|
||||
},
|
||||
{
|
||||
name: "mixed case and wildcard",
|
||||
addPattern: "*.Example.Site.",
|
||||
removePattern: "*.EXAMPLE.SITE.",
|
||||
queryPattern: "sub.example.site.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding mixed case wildcard, removing uppercase wildcard",
|
||||
},
|
||||
{
|
||||
name: "root zone",
|
||||
addPattern: ".",
|
||||
removePattern: ".",
|
||||
queryPattern: "random.domain.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing root zone",
|
||||
},
|
||||
{
|
||||
name: "wrong domain",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "different.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding one domain, trying to remove a different domain",
|
||||
},
|
||||
{
|
||||
name: "subdomain mismatch",
|
||||
addPattern: "sub.example.com.",
|
||||
removePattern: "example.com.",
|
||||
queryPattern: "sub.example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding subdomain, trying to remove parent domain",
|
||||
},
|
||||
{
|
||||
name: "parent domain mismatch",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "sub.example.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding parent domain, trying to remove subdomain",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
handler := &nbdns.MockHandler{}
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// First verify no handler is called before adding any
|
||||
chain.ServeDNS(w, r)
|
||||
handler.AssertNotCalled(t, "ServeDNS")
|
||||
|
||||
// Add handler
|
||||
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
|
||||
|
||||
// Verify handler is called after adding
|
||||
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||
chain.ServeDNS(w, r)
|
||||
handler.AssertExpectations(t)
|
||||
|
||||
// Reset mock for the next test
|
||||
handler.ExpectedCalls = nil
|
||||
|
||||
// Remove handler
|
||||
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
|
||||
|
||||
// Set up expectations based on whether removal should succeed
|
||||
if !tt.shouldBeRemoved {
|
||||
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||
}
|
||||
|
||||
// Test if handler is still called after removal attempt
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
if tt.shouldBeRemoved {
|
||||
handler.AssertNotCalled(t, "ServeDNS",
|
||||
"Handler should not be called after successful removal with pattern %q",
|
||||
tt.removePattern)
|
||||
} else {
|
||||
handler.AssertExpectations(t)
|
||||
handler.ExpectedCalls = nil
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
@ -12,8 +14,8 @@ import (
|
||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa"
|
||||
ipv6ReverseZone = ".ip6.arpa"
|
||||
ipv4ReverseZone = ".in-addr.arpa."
|
||||
ipv6ReverseZone = ".ip6.arpa."
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
|
||||
for _, domain := range nsConfig.Domains {
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(domain, "."),
|
||||
Domain: strings.ToLower(dns.Fqdn(domain)),
|
||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
})
|
||||
}
|
||||
@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, dConf.Domain)
|
||||
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
|
@ -17,9 +17,12 @@ import (
|
||||
|
||||
var (
|
||||
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||
|
||||
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||
)
|
||||
|
||||
const (
|
||||
@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
continue
|
||||
}
|
||||
if !dConf.MatchOnly {
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -184,6 +191,26 @@ func (r *registryConfigurator) string() string {
|
||||
return "registry"
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) flushDNSCache() error {
|
||||
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
|
||||
}
|
||||
}()
|
||||
|
||||
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||
if ret == 0 {
|
||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
@ -236,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// MockServer is the mock instance of a dns server
|
||||
@ -13,17 +14,17 @@ type MockServer struct {
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||
DeregisterHandlerFunc func([]string, int)
|
||||
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||
DeregisterHandlerFunc func(domain.List, int)
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
if m.RegisterHandlerFunc != nil {
|
||||
m.RegisterHandlerFunc(domains, handler, priority)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
||||
func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
if m.DeregisterHandlerFunc != nil {
|
||||
m.DeregisterHandlerFunc(domains, priority)
|
||||
}
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
||||
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
|
||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||
|
@ -6,11 +6,13 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@ -18,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
@ -32,8 +35,8 @@ type IosDnsManager interface {
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains []string, priority int)
|
||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains domain.List, priority int)
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() string
|
||||
@ -65,6 +68,7 @@ type DefaultServer struct {
|
||||
previousConfigHash uint64
|
||||
currentConfig HostDNSConfig
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
@ -164,13 +168,15 @@ func newDefaultServer(
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
handlerChain := NewHandlerChain()
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
@ -181,14 +187,26 @@ func newDefaultServer(
|
||||
hostsDNSHolder: newHostsDNSHolder(),
|
||||
}
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
dnsService.RegisterMux(".", handlerChain)
|
||||
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||
// Any previously registered handler for the same domain and priority will be replaced.
|
||||
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.registerHandler(domains, handler, priority)
|
||||
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
||||
|
||||
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||
for _, domain := range domains {
|
||||
// convert to zone with simple ref counter
|
||||
s.extraDomains[toZone(domain)]++
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
||||
continue
|
||||
}
|
||||
s.handlerChain.AddHandler(domain, handler, priority)
|
||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
||||
// DeregisterHandler deregisters the handler for the given domains with the given priority.
|
||||
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.deregisterHandler(domains, priority)
|
||||
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
||||
for _, domain := range domains {
|
||||
zone := toZone(domain)
|
||||
s.extraDomains[zone]--
|
||||
if s.extraDomains[zone] <= 0 {
|
||||
delete(s.extraDomains, zone)
|
||||
}
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
}
|
||||
|
||||
s.handlerChain.RemoveHandler(domain, priority)
|
||||
|
||||
// Only deregister from service if no handlers remain
|
||||
if !s.handlerChain.HasHandlers(domain) {
|
||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
|
||||
s.service.Stop()
|
||||
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be Disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
if err := s.service.Listen(); err != nil {
|
||||
log.Errorf("failed to start DNS service: %v", err)
|
||||
}
|
||||
} else if !s.permanent {
|
||||
s.service.Stop()
|
||||
}
|
||||
@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
hostUpdate.RouteAll = false
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
||||
log.Error(err)
|
||||
s.handleErrNoGroupaAll(err)
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
@ -441,6 +462,38 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyHostConfig() {
|
||||
if s.hostManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
config := s.currentConfig
|
||||
|
||||
existingDomains := make(map[string]struct{})
|
||||
for _, d := range config.Domains {
|
||||
existingDomains[d.Domain] = struct{}{}
|
||||
}
|
||||
|
||||
// add extra domains only if they're not already in the config
|
||||
for domain := range s.extraDomains {
|
||||
domainStr := domain.PunycodeString()
|
||||
|
||||
if _, exists := existingDomains[domainStr]; !exists {
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: domainStr,
|
||||
MatchOnly: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||
s.handleErrNoGroupaAll(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
||||
return
|
||||
@ -690,10 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
@ -728,12 +778,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
if s.hostManager != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
}
|
||||
@ -836,3 +881,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func toZone(d domain.Domain) domain.Domain {
|
||||
return domain.Domain(
|
||||
nbdns.NormalizeZone(
|
||||
dns.Fqdn(
|
||||
strings.ToLower(d.PunycodeString()),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
@ -38,7 +39,7 @@ type mocWGIface struct {
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
return "utun2301"
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
@ -1448,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialConfig nbdns.Config
|
||||
registerDomains []domain.List
|
||||
deregisterDomains []domain.List
|
||||
finalConfig nbdns.Config
|
||||
expectedDomains []string
|
||||
expectedMatchOnly []string
|
||||
applyHostConfigCall int
|
||||
}{
|
||||
{
|
||||
name: "Register domains before config update",
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
},
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register domains after config update",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register overlapping domains",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "overlap.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"overlap.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register and deregister domains",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
{"extra3.example.com", "extra4.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"extra1.example.com", "extra3.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
},
|
||||
{
|
||||
name: "Register domains with ref counter",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "duplicate.example.com"},
|
||||
{"other.example.com", "duplicate.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"duplicate.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
},
|
||||
{
|
||||
name: "Config update with new domains after registration",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "duplicate.example.com"},
|
||||
},
|
||||
finalConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "newconfig.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"newconfig.example.com.",
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Deregister domain that is part of customZones",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "protected.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "protected.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"protected.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"extra.example.com.",
|
||||
"config.example.com.",
|
||||
"protected.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Register domain that is part of nameserver group",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.ns.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedConfigs []HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfigs = append(capturedConfigs, config)
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
wgInterface: &mocWGIface{},
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
// Apply initial configuration
|
||||
if tt.initialConfig.ServiceEnable {
|
||||
err := server.applyConfiguration(tt.initialConfig)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Register domains
|
||||
for _, domains := range tt.registerDomains {
|
||||
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
|
||||
}
|
||||
|
||||
// Deregister domains if specified
|
||||
for _, domains := range tt.deregisterDomains {
|
||||
server.DeregisterHandler(domains, PriorityDefault)
|
||||
}
|
||||
|
||||
// Apply final configuration if specified
|
||||
if tt.finalConfig.ServiceEnable {
|
||||
err := server.applyConfiguration(tt.finalConfig)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify number of calls
|
||||
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
|
||||
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
|
||||
|
||||
// Get the last applied config
|
||||
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||
|
||||
// Check all expected domains are present
|
||||
domainMap := make(map[string]bool)
|
||||
matchOnlyMap := make(map[string]bool)
|
||||
|
||||
for _, d := range lastConfig.Domains {
|
||||
domainMap[d.Domain] = true
|
||||
if d.MatchOnly {
|
||||
matchOnlyMap[d.Domain] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Verify expected domains
|
||||
for _, d := range tt.expectedDomains {
|
||||
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
|
||||
}
|
||||
|
||||
// Verify match-only domains
|
||||
for _, d := range tt.expectedMatchOnly {
|
||||
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
|
||||
}
|
||||
|
||||
// Verify no unexpected domains
|
||||
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
|
||||
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraDomainsRefCounting(t *testing.T) {
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
// Register domains from different handlers with same domain
|
||||
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 2
|
||||
zoneKey := toZone("shared.example.com")
|
||||
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||
|
||||
// Deregister one handler
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 1
|
||||
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||
|
||||
// Deregister the other handler
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
|
||||
|
||||
// Verify domain is removed
|
||||
_, exists := server.extraDomains[zoneKey]
|
||||
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
|
||||
}
|
||||
|
||||
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
var capturedConfig HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfig = config
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
|
||||
|
||||
initialConfig := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
}
|
||||
err := server.applyConfiguration(initialConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
|
||||
// Now apply a new configuration with overlapping domain
|
||||
updatedConfig := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "extra.example.com"},
|
||||
},
|
||||
}
|
||||
err = server.applyConfiguration(updatedConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify both domains are in config, but no duplicates
|
||||
domains = []string{}
|
||||
matchOnlyCount := 0
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
if d.MatchOnly {
|
||||
matchOnlyCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||
|
||||
// Extra domain should no longer be marked as match-only when in config
|
||||
matchOnlyDomain := ""
|
||||
for _, d := range capturedConfig.Domains {
|
||||
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||
matchOnlyDomain = d.Domain
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
|
||||
}
|
||||
|
||||
func TestDomainCaseHandling(t *testing.T) {
|
||||
var capturedConfig HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfig = config
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
}
|
||||
err := server.applyConfiguration(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
@ -111,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
continue
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: dns.Fqdn(dConf.Domain),
|
||||
Domain: dConf.Domain,
|
||||
MatchOnly: dConf.MatchOnly,
|
||||
})
|
||||
|
||||
@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||
}
|
||||
return s.flushCaches()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||
}
|
||||
|
||||
return s.flushCaches()
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
||||
func (s *systemdDbusConfigurator) flushDNSCache() error {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||
|
@ -23,9 +23,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
UpstreamTimeout = 15 * time.Second
|
||||
|
||||
failsTillDeact = int32(5)
|
||||
reactivatePeriod = 30 * time.Second
|
||||
upstreamTimeout = 15 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
@ -66,7 +67,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: domain,
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
statusRecorder: statusRecorder,
|
||||
|
@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
timeout := upstreamTimeout
|
||||
timeout := UpstreamTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout = time.Until(deadline)
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||
}
|
||||
|
||||
timeout := upstreamTimeout
|
||||
timeout := UpstreamTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout = time.Until(deadline)
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
name: "Should Resolve A Record",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
timeout: upstreamTimeout,
|
||||
timeout: UpstreamTimeout,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
cancelCTX: true,
|
||||
timeout: upstreamTimeout,
|
||||
timeout: UpstreamTimeout,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
r: new(dns.Msg),
|
||||
rtt: time.Millisecond,
|
||||
},
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
@ -60,7 +59,7 @@ func (d *DnsInterceptor) String() string {
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||
d.dnsServer.RegisterHandler(d.route.Domains, d, nbdns.PriorityDNSRoute)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -89,7 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
clear(d.interceptedDomains)
|
||||
d.mu.Unlock()
|
||||
|
||||
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
|
||||
d.dnsServer.DeregisterHandler(d.route.Domains, nbdns.PriorityDNSRoute)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@ -142,21 +141,24 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
// pass if non A/AAAA query
|
||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||
d.continueToNextHandler(w, r, "non A/AAAA query")
|
||||
return
|
||||
}
|
||||
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
|
||||
|
||||
d.continueToNextHandler(w, r, "no current peer key")
|
||||
d.writeDNSError(w, r, "no current peer key")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get upstream IP: %v", err)
|
||||
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
@ -165,34 +167,43 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: nbdns.UpstreamTimeout,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS error response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||
|
@ -235,7 +235,7 @@ func (r *Route) resolve(results chan resolveResult) {
|
||||
ips, err := r.getIPsFromResolver(domain)
|
||||
if err != nil {
|
||||
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
|
||||
ips, err = net.LookupIP(string(domain))
|
||||
ips, err = net.LookupIP(domain.PunycodeString())
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
|
@ -9,5 +9,5 @@ import (
|
||||
)
|
||||
|
||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
return net.LookupIP(string(domain))
|
||||
return net.LookupIP(domain.PunycodeString())
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
|
||||
msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
|
@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
|
||||
// Convert to proto format
|
||||
for domain, ips := range domainMap {
|
||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
||||
pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{
|
||||
Ips: ips,
|
||||
}
|
||||
}
|
||||
|
@ -24,6 +24,11 @@ func (d Domain) SafeString() string {
|
||||
return str
|
||||
}
|
||||
|
||||
// PunycodeString returns the punycode representation of the Domain.
|
||||
func (d Domain) PunycodeString() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// FromString creates a Domain from a string, converting it to punycode.
|
||||
func FromString(s string) (Domain, error) {
|
||||
ascii, err := idna.ToASCII(s)
|
||||
|
Reference in New Issue
Block a user