[client] Automatically register match domains for DNS routes (#3614)

This commit is contained in:
Viktor Liu
2025-04-07 15:18:45 +02:00
committed by GitHub
parent 6162aeb82d
commit 87e600a4f3
21 changed files with 892 additions and 120 deletions

View File

@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
continue continue
} }
listOfDomains = append(listOfDomains, dConf.Domain) listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
} }
return listOfDomains return listOfDomains
} }

View File

@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
} }
// First remove any existing handler with same pattern (case-insensitive) and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { c.removeEntry(origPattern, priority)
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
// Check if handler implements SubdomainMatcher interface // Check if handler implements SubdomainMatcher interface
matchSubdomains := false matchSubdomains := false
@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
pattern = dns.Fqdn(pattern) pattern = dns.Fqdn(pattern)
c.removeEntry(pattern, priority)
}
func (c *HandlerChain) removeEntry(pattern string, priority int) {
// Find and remove handlers matching both original pattern (case-insensitive) and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return break
} }
} }
} }
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers {
if strings.EqualFold(entry.Pattern, pattern) {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return

View File

@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
for _, handler := range handlers { for _, handler := range handlers {
handler.AssertExpectations(t) handler.AssertExpectations(t)
} }
// Verify handler exists check
for priority, shouldExist := range tt.expectedCalls {
if shouldExist {
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
"Handler chain should have handlers for pattern after removing priority %d", priority)
}
}
}) })
} }
} }
@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(testQuery, dns.TypeA) r.SetQuestion(testQuery, dns.TypeA)
// Keep track of mocks for the final assertion in Step 4
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state with all three handlers // Test 1: Initial state
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Highest priority handler (routeHandler) should be called // Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
chain.ServeDNS(w, r) chain.ServeDNS(w1, r)
routeHandler.AssertExpectations(t) routeHandler.AssertExpectations(t)
routeHandler.ExpectedCalls = nil
routeHandler.Calls = nil
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 2: Remove highest priority handler // Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called // Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
chain.ServeDNS(w, r) chain.ServeDNS(w2, r)
matchHandler.AssertExpectations(t) matchHandler.AssertExpectations(t)
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r) chain.ServeDNS(w3, r)
defaultHandler.AssertExpectations(t) defaultHandler.AssertExpectations(t)
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 4: Remove last handler // Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault) chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain)) w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks {
m.AssertNumberOfCalls(t, "ServeDNS", 0)
}
} }
func TestHandlerChain_CaseSensitivity(t *testing.T) { func TestHandlerChain_CaseSensitivity(t *testing.T) {
@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
}) })
} }
} }
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
tests := []struct {
name string
addPattern string
removePattern string
queryPattern string
shouldBeRemoved bool
description string
}{
{
name: "exact same pattern",
addPattern: "example.com.",
removePattern: "example.com.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical patterns",
},
{
name: "case difference",
addPattern: "Example.Com.",
removePattern: "EXAMPLE.COM.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding with mixed case, removing with uppercase",
},
{
name: "reversed case difference",
addPattern: "EXAMPLE.ORG.",
removePattern: "example.org.",
queryPattern: "example.org.",
shouldBeRemoved: true,
description: "Adding with uppercase, removing with lowercase",
},
{
name: "add wildcard, remove wildcard",
addPattern: "*.example.com.",
removePattern: "*.example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical wildcard patterns",
},
{
name: "add wildcard, remove transformed pattern",
addPattern: "*.example.net.",
removePattern: "example.net.",
queryPattern: "sub.example.net.",
shouldBeRemoved: false,
description: "Adding with wildcard, removing with non-wildcard pattern",
},
{
name: "add transformed pattern, remove wildcard",
addPattern: "example.io.",
removePattern: "*.example.io.",
queryPattern: "example.io.",
shouldBeRemoved: false,
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
},
{
name: "trailing dot difference",
addPattern: "example.dev",
removePattern: "example.dev.",
queryPattern: "example.dev.",
shouldBeRemoved: true,
description: "Adding without trailing dot, removing with trailing dot",
},
{
name: "reversed trailing dot difference",
addPattern: "example.app.",
removePattern: "example.app",
queryPattern: "example.app.",
shouldBeRemoved: true,
description: "Adding with trailing dot, removing without trailing dot",
},
{
name: "mixed case and wildcard",
addPattern: "*.Example.Site.",
removePattern: "*.EXAMPLE.SITE.",
queryPattern: "sub.example.site.",
shouldBeRemoved: true,
description: "Adding mixed case wildcard, removing uppercase wildcard",
},
{
name: "root zone",
addPattern: ".",
removePattern: ".",
queryPattern: "random.domain.",
shouldBeRemoved: true,
description: "Adding and removing root zone",
},
{
name: "wrong domain",
addPattern: "example.com.",
removePattern: "different.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding one domain, trying to remove a different domain",
},
{
name: "subdomain mismatch",
addPattern: "sub.example.com.",
removePattern: "example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: false,
description: "Adding subdomain, trying to remove parent domain",
},
{
name: "parent domain mismatch",
addPattern: "example.com.",
removePattern: "sub.example.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding parent domain, trying to remove subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// First verify no handler is called before adding any
chain.ServeDNS(w, r)
handler.AssertNotCalled(t, "ServeDNS")
// Add handler
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
// Verify handler is called after adding
handler.On("ServeDNS", mock.Anything, r).Once()
chain.ServeDNS(w, r)
handler.AssertExpectations(t)
// Reset mock for the next test
handler.ExpectedCalls = nil
// Remove handler
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
// Set up expectations based on whether removal should succeed
if !tt.shouldBeRemoved {
handler.On("ServeDNS", mock.Anything, r).Once()
}
// Test if handler is still called after removal attempt
chain.ServeDNS(w, r)
if tt.shouldBeRemoved {
handler.AssertNotCalled(t, "ServeDNS",
"Handler should not be called after successful removal with pattern %q",
tt.removePattern)
} else {
handler.AssertExpectations(t)
handler.ExpectedCalls = nil
}
})
}
}

View File

@ -5,6 +5,8 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@ -12,8 +14,8 @@ import (
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
const ( const (
ipv4ReverseZone = ".in-addr.arpa" ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa" ipv6ReverseZone = ".ip6.arpa."
) )
type hostManager interface { type hostManager interface {
@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
for _, domain := range nsConfig.Domains { for _, domain := range nsConfig.Domains {
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(domain, "."), Domain: strings.ToLower(dns.Fqdn(domain)),
MatchOnly: !nsConfig.SearchDomainsEnabled, MatchOnly: !nsConfig.SearchDomainsEnabled,
}) })
} }
@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
for _, customZone := range dnsConfig.CustomZones { for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."), Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
MatchOnly: matchOnly, MatchOnly: matchOnly,
}) })
} }

View File

@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain) matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
continue continue
} }
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)

View File

@ -17,9 +17,12 @@ import (
var ( var (
userenv = syscall.NewLazyDLL("userenv.dll") userenv = syscall.NewLazyDLL("userenv.dll")
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex // https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx") refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
) )
const ( const (
@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
continue continue
} }
if !dConf.MatchOnly { if !dConf.MatchOnly {
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
} }
matchDomains = append(matchDomains, "."+dConf.Domain) matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
} }
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }
@ -184,6 +191,26 @@ func (r *registryConfigurator) string() string {
return "registry" return "registry"
} }
func (r *registryConfigurator) flushDNSCache() error {
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() {
if rec := recover(); rec != nil {
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
}
}()
ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
}
return fmt.Errorf("DnsFlushResolverCache failed")
}
log.Info("flushed DNS cache")
return nil
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
@ -236,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
// MockServer is the mock instance of a dns server // MockServer is the mock instance of a dns server
@ -13,17 +14,17 @@ type MockServer struct {
InitializeFunc func() error InitializeFunc func() error
StopFunc func() StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler, int) RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func([]string, int) DeregisterHandlerFunc func(domain.List, int)
} }
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
if m.RegisterHandlerFunc != nil { if m.RegisterHandlerFunc != nil {
m.RegisterHandlerFunc(domains, handler, priority) m.RegisterHandlerFunc(domains, handler, priority)
} }
} }
func (m *MockServer) DeregisterHandler(domains []string, priority int) { func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
if m.DeregisterHandlerFunc != nil { if m.DeregisterHandlerFunc != nil {
m.DeregisterHandlerFunc(domains, priority) m.DeregisterHandlerFunc(domains, priority)
} }

View File

@ -13,7 +13,6 @@ import (
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain)) matchDomains = append(matchDomains, "~."+dConf.Domain)
continue continue
} }
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain)) searchDomains = append(searchDomains, dConf.Domain)
} }
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic

View File

@ -6,11 +6,13 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"runtime" "runtime"
"strings"
"sync" "sync"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@ -18,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
@ -32,8 +35,8 @@ type IosDnsManager interface {
// Server is a dns server interface // Server is a dns server interface
type Server interface { type Server interface {
RegisterHandler(domains []string, handler dns.Handler, priority int) RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains []string, priority int) DeregisterHandler(domains domain.List, priority int)
Initialize() error Initialize() error
Stop() Stop()
DnsIP() string DnsIP() string
@ -65,6 +68,7 @@ type DefaultServer struct {
previousConfigHash uint64 previousConfigHash uint64
currentConfig HostDNSConfig currentConfig HostDNSConfig
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int
// permanent related properties // permanent related properties
permanent bool permanent bool
@ -164,13 +168,15 @@ func newDefaultServer(
stateManager *statemanager.Manager, stateManager *statemanager.Manager,
disableSys bool, disableSys bool,
) *DefaultServer { ) *DefaultServer {
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
disableSys: disableSys, disableSys: disableSys,
service: dnsService, service: dnsService,
handlerChain: NewHandlerChain(), handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
@ -181,14 +187,26 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(), hostsDNSHolder: newHostsDNSHolder(),
} }
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
return defaultServer return defaultServer
} }
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { // RegisterHandler registers a handler for the given domains with the given priority.
// Any previously registered handler for the same domain and priority will be replaced.
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.registerHandler(domains, handler, priority) s.registerHandler(domains.ToPunycodeList(), handler, priority)
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
for _, domain := range domains {
// convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++
}
s.applyHostConfig()
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
continue continue
} }
s.handlerChain.AddHandler(domain, handler, priority) s.handlerChain.AddHandler(domain, handler, priority)
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
} }
} }
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) { // DeregisterHandler deregisters the handler for the given domains with the given priority.
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.deregisterHandler(domains, priority) s.deregisterHandler(domains.ToPunycodeList(), priority)
for _, domain := range domains {
zone := toZone(domain)
s.extraDomains[zone]--
if s.extraDomains[zone] <= 0 {
delete(s.extraDomains, zone)
}
}
s.applyHostConfig()
} }
func (s *DefaultServer) deregisterHandler(domains []string, priority int) { func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
} }
s.handlerChain.RemoveHandler(domain, priority) s.handlerChain.RemoveHandler(domain, priority)
// Only deregister from service if no handlers remain
if !s.handlerChain.HasHandlers(domain) {
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
}
} }
} }
@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
} }
s.service.Stop() s.service.Stop()
maps.Clear(s.extraDomains)
} }
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver // is the service should be Disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records // and proceed with a regular update to clean up the handlers and records
if update.ServiceEnable { if update.ServiceEnable {
_ = s.service.Listen() if err := s.service.Listen(); err != nil {
log.Errorf("failed to start DNS service: %v", err)
}
} else if !s.permanent { } else if !s.permanent {
s.service.Stop() s.service.Stop()
} }
@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
hostUpdate.RouteAll = false s.currentConfig.RouteAll = false
} }
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { s.applyHostConfig()
log.Error(err)
s.handleErrNoGroupaAll(err)
}
go func() { go func() {
// persist dns state right away // persist dns state right away
@ -441,6 +462,38 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
return nil return nil
} }
func (s *DefaultServer) applyHostConfig() {
if s.hostManager == nil {
return
}
config := s.currentConfig
existingDomains := make(map[string]struct{})
for _, d := range config.Domains {
existingDomains[d.Domain] = struct{}{}
}
// add extra domains only if they're not already in the config
for domain := range s.extraDomains {
domainStr := domain.PunycodeString()
if _, exists := existingDomains[domainStr]; !exists {
config.Domains = append(config.Domains, DomainConfig{
Domain: domainStr,
MatchOnly: true,
})
}
}
log.Debugf("extra match domains: %v", s.extraDomains)
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
s.handleErrNoGroupaAll(err)
}
}
func (s *DefaultServer) handleErrNoGroupaAll(err error) { func (s *DefaultServer) handleErrNoGroupaAll(err error) {
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) { if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
return return
@ -690,10 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
} }
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { s.applyHostConfig()
s.handleErrNoGroupaAll(err)
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
go func() { go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil { if err := s.stateManager.PersistState(s.ctx); err != nil {
@ -728,12 +778,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority) s.registerHandler([]string{nbdns.RootZone}, handler, priority)
} }
if s.hostManager != nil { s.applyHostConfig()
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
s.handleErrNoGroupaAll(err)
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
}
s.updateNSState(nsGroup, nil, true) s.updateNSState(nsGroup, nil, true)
} }
@ -836,3 +881,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
return result return result
} }
func toZone(d domain.Domain) domain.Domain {
return domain.Domain(
nbdns.NormalizeZone(
dns.Fqdn(
strings.ToLower(d.PunycodeString()),
),
),
)
}

View File

@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
) )
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
@ -38,7 +39,7 @@ type mocWGIface struct {
} }
func (w *mocWGIface) Name() string { func (w *mocWGIface) Name() string {
panic("implement me") return "utun2301"
} }
func (w *mocWGIface) Address() wgaddr.Address { func (w *mocWGIface) Address() wgaddr.Address {
@ -1448,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}) })
} }
} }
func TestExtraDomains(t *testing.T) {
tests := []struct {
name string
initialConfig nbdns.Config
registerDomains []domain.List
deregisterDomains []domain.List
finalConfig nbdns.Config
expectedDomains []string
expectedMatchOnly []string
applyHostConfigCall int
}{
{
name: "Register domains before config update",
registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"},
},
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
expectedDomains: []string{
"config.example.com.",
"extra1.example.com.",
"extra2.example.com.",
},
expectedMatchOnly: []string{
"extra1.example.com.",
"extra2.example.com.",
},
applyHostConfigCall: 2,
},
{
name: "Register domains after config update",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"extra1.example.com.",
"extra2.example.com.",
},
expectedMatchOnly: []string{
"extra1.example.com.",
"extra2.example.com.",
},
applyHostConfigCall: 2,
},
{
name: "Register overlapping domains",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "overlap.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "overlap.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"overlap.example.com.",
"extra.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 2,
},
{
name: "Register and deregister domains",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"},
{"extra3.example.com", "extra4.example.com"},
},
deregisterDomains: []domain.List{
{"extra1.example.com", "extra3.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"extra2.example.com.",
"extra4.example.com.",
},
expectedMatchOnly: []string{
"extra2.example.com.",
"extra4.example.com.",
},
applyHostConfigCall: 4,
},
{
name: "Register domains with ref counter",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "duplicate.example.com"},
{"other.example.com", "duplicate.example.com"},
},
deregisterDomains: []domain.List{
{"duplicate.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"extra.example.com.",
"other.example.com.",
"duplicate.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
"other.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 4,
},
{
name: "Config update with new domains after registration",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "duplicate.example.com"},
},
finalConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "newconfig.example.com"},
},
},
expectedDomains: []string{
"config.example.com.",
"newconfig.example.com.",
"extra.example.com.",
"duplicate.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 3,
},
{
name: "Deregister domain that is part of customZones",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "protected.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "protected.example.com"},
},
deregisterDomains: []domain.List{
{"protected.example.com"},
},
expectedDomains: []string{
"extra.example.com.",
"config.example.com.",
"protected.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 3,
},
{
name: "Register domain that is part of nameserver group",
initialConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
},
registerDomains: []domain.List{
{"extra.example.com", "overlap.ns.example.com"},
},
expectedDomains: []string{
"ns.example.com.",
"overlap.ns.example.com.",
"extra.example.com.",
},
expectedMatchOnly: []string{
"ns.example.com.",
"overlap.ns.example.com.",
"extra.example.com.",
},
applyHostConfigCall: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedConfigs []HostDNSConfig
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
capturedConfigs = append(capturedConfigs, config)
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
wgInterface: &mocWGIface{},
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
// Apply initial configuration
if tt.initialConfig.ServiceEnable {
err := server.applyConfiguration(tt.initialConfig)
assert.NoError(t, err)
}
// Register domains
for _, domains := range tt.registerDomains {
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
}
// Deregister domains if specified
for _, domains := range tt.deregisterDomains {
server.DeregisterHandler(domains, PriorityDefault)
}
// Apply final configuration if specified
if tt.finalConfig.ServiceEnable {
err := server.applyConfiguration(tt.finalConfig)
assert.NoError(t, err)
}
// Verify number of calls
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
// Get the last applied config
lastConfig := capturedConfigs[len(capturedConfigs)-1]
// Check all expected domains are present
domainMap := make(map[string]bool)
matchOnlyMap := make(map[string]bool)
for _, d := range lastConfig.Domains {
domainMap[d.Domain] = true
if d.MatchOnly {
matchOnlyMap[d.Domain] = true
}
}
// Verify expected domains
for _, d := range tt.expectedDomains {
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
}
// Verify match-only domains
for _, d := range tt.expectedMatchOnly {
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
}
// Verify no unexpected domains
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
})
}
}
func TestExtraDomainsRefCounting(t *testing.T) {
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
// Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
// Verify refcount is 2
zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
// Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
// Deregister the other handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
// Verify domain is removed
_, exists := server.extraDomains[zoneKey]
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
}
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
var capturedConfig HostDNSConfig
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
capturedConfig = config
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
initialConfig := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
}
err := server.applyConfiguration(initialConfig)
assert.NoError(t, err)
var domains []string
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
}
assert.Contains(t, domains, "config.example.com.")
assert.Contains(t, domains, "extra.example.com.")
// Now apply a new configuration with overlapping domain
updatedConfig := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "extra.example.com"},
},
}
err = server.applyConfiguration(updatedConfig)
assert.NoError(t, err)
// Verify both domains are in config, but no duplicates
domains = []string{}
matchOnlyCount := 0
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
if d.MatchOnly {
matchOnlyCount++
}
}
assert.Contains(t, domains, "config.example.com.")
assert.Contains(t, domains, "extra.example.com.")
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
// Extra domain should no longer be marked as match-only when in config
matchOnlyDomain := ""
for _, d := range capturedConfig.Domains {
if d.Domain == "extra.example.com." && d.MatchOnly {
matchOnlyDomain = d.Domain
break
}
}
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
}
func TestDomainCaseHandling(t *testing.T) {
var capturedConfig HostDNSConfig
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
capturedConfig = config
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
}
err := server.applyConfiguration(config)
assert.NoError(t, err)
var domains []string
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
}
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
}

View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -111,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
continue continue
} }
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dns.Fqdn(dConf.Domain), Domain: dConf.Domain,
MatchOnly: dConf.MatchOnly, MatchOnly: dConf.MatchOnly,
}) })
@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }
@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
if err != nil { if err != nil {
return fmt.Errorf("setting domains configuration failed with error: %w", err) return fmt.Errorf("setting domains configuration failed with error: %w", err)
} }
return s.flushCaches()
return nil
} }
func (s *systemdDbusConfigurator) restoreHostDNS() error { func (s *systemdDbusConfigurator) restoreHostDNS() error {
@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("unable to revert link configuration, got error: %w", err) return fmt.Errorf("unable to revert link configuration, got error: %w", err)
} }
return s.flushCaches() if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
} }
func (s *systemdDbusConfigurator) flushCaches() error { func (s *systemdDbusConfigurator) flushDNSCache() error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil { if err != nil {
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err) return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)

View File

@ -23,9 +23,10 @@ import (
) )
const ( const (
UpstreamTimeout = 15 * time.Second
failsTillDeact = int32(5) failsTillDeact = int32(5)
reactivatePeriod = 30 * time.Second reactivatePeriod = 30 * time.Second
upstreamTimeout = 15 * time.Second
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
) )
@ -66,7 +67,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
domain: domain, domain: domain,
upstreamTimeout: upstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,

View File

@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN // exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
timeout := upstreamTimeout timeout := UpstreamTimeout
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline) timeout = time.Until(deadline)
} }

View File

@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
} }
timeout := upstreamTimeout timeout := UpstreamTimeout
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline) timeout = time.Until(deadline)
} }

View File

@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
name: "Should Resolve A Record", name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"}, InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: upstreamTimeout, timeout: UpstreamTimeout,
expectedAnswer: "1.1.1.1", expectedAnswer: "1.1.1.1",
}, },
{ {
@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"}, InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true, cancelCTX: true,
timeout: upstreamTimeout, timeout: UpstreamTimeout,
responseShouldBeNil: true, responseShouldBeNil: true,
}, },
} }
@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
r: new(dns.Msg), r: new(dns.Msg),
rtt: time.Millisecond, rtt: time.Millisecond,
}, },
upstreamTimeout: upstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
} }

View File

@ -6,7 +6,6 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -60,7 +59,7 @@ func (d *DnsInterceptor) String() string {
} }
func (d *DnsInterceptor) AddRoute(context.Context) error { func (d *DnsInterceptor) AddRoute(context.Context) error {
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute) d.dnsServer.RegisterHandler(d.route.Domains, d, nbdns.PriorityDNSRoute)
return nil return nil
} }
@ -89,7 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
clear(d.interceptedDomains) clear(d.interceptedDomains)
d.mu.Unlock() d.mu.Unlock()
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) d.dnsServer.DeregisterHandler(d.route.Domains, nbdns.PriorityDNSRoute)
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
@ -142,21 +141,24 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received DNS request for domain=%s type=%v class=%v", log.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, "non A/AAAA query")
return
}
d.mu.RLock() d.mu.RLock()
peerKey := d.currentPeerKey peerKey := d.currentPeerKey
d.mu.RUnlock() d.mu.RUnlock()
if peerKey == "" { if peerKey == "" {
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name) d.writeDNSError(w, r, "no current peer key")
d.continueToNextHandler(w, r, "no current peer key")
return return
} }
upstreamIP, err := d.getUpstreamIP(peerKey) upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil { if err != nil {
log.Errorf("failed to get upstream IP: %v", err) d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
return return
} }
@ -165,34 +167,43 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.SetEdns0(4096, false) r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
client := &dns.Client{ client := &dns.Client{
Timeout: 5 * time.Second, Timeout: nbdns.UpstreamTimeout,
Net: "udp", Net: "udp",
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream) reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
if err != nil { if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err) log.Errorf("failed writing DNS response: %v", err)
} }
return return
} }
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil { if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err) log.Errorf("failed writing DNS response: %v", err)
} }
} }
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS error response: %v", err)
}
}
// continueToNextHandler signals the handler chain to try the next handler // continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)

View File

@ -235,7 +235,7 @@ func (r *Route) resolve(results chan resolveResult) {
ips, err := r.getIPsFromResolver(domain) ips, err := r.getIPsFromResolver(domain)
if err != nil { if err != nil {
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err) log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
ips, err = net.LookupIP(string(domain)) ips, err = net.LookupIP(domain.PunycodeString())
if err != nil { if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return return

View File

@ -9,5 +9,5 @@ import (
) )
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
return net.LookupIP(string(domain)) return net.LookupIP(domain.PunycodeString())
} }

View File

@ -23,7 +23,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
} }
msg := new(dns.Msg) msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA) msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA)
startTime := time.Now() startTime := time.Now()

View File

@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
// Convert to proto format // Convert to proto format
for domain, ips := range domainMap { for domain, ips := range domainMap {
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{
Ips: ips, Ips: ips,
} }
} }

View File

@ -24,6 +24,11 @@ func (d Domain) SafeString() string {
return str return str
} }
// PunycodeString returns the punycode representation of the Domain.
func (d Domain) PunycodeString() string {
return string(d)
}
// FromString creates a Domain from a string, converting it to punycode. // FromString creates a Domain from a string, converting it to punycode.
func FromString(s string) (Domain, error) { func FromString(s string) (Domain, error) {
ascii, err := idna.ToASCII(s) ascii, err := idna.ToASCII(s)