mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 18:11:58 +02:00
[client] Automatically register match domains for DNS routes (#3614)
This commit is contained in:
@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
return listOfDomains
|
return listOfDomains
|
||||||
}
|
}
|
||||||
|
@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
c.removeEntry(origPattern, priority)
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if handler implements SubdomainMatcher interface
|
// Check if handler implements SubdomainMatcher interface
|
||||||
matchSubdomains := false
|
matchSubdomains := false
|
||||||
@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
|
c.removeEntry(pattern, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
|
||||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
|
||||||
for _, entry := range c.handlers {
|
|
||||||
if strings.EqualFold(entry.Pattern, pattern) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
|
@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
for _, handler := range handlers {
|
for _, handler := range handlers {
|
||||||
handler.AssertExpectations(t)
|
handler.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify handler exists check
|
|
||||||
for priority, shouldExist := range tt.expectedCalls {
|
|
||||||
if shouldExist {
|
|
||||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
|
||||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
|
// Keep track of mocks for the final assertion in Step 4
|
||||||
|
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Highest priority handler (routeHandler) should be called
|
// Highest priority handler (routeHandler) should be called
|
||||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w1, r)
|
||||||
routeHandler.AssertExpectations(t)
|
routeHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
routeHandler.ExpectedCalls = nil
|
||||||
|
routeHandler.Calls = nil
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 2: Remove highest priority handler
|
// Test 2: Remove highest priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Now middle priority handler (matchHandler) should be called
|
// Now middle priority handler (matchHandler) should be called
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w2, r)
|
||||||
matchHandler.AssertExpectations(t)
|
matchHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w3, r)
|
||||||
defaultHandler.AssertExpectations(t)
|
defaultHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
assert.False(t, chain.HasHandlers(testDomain))
|
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||||
|
|
||||||
|
for _, m := range mocks {
|
||||||
|
m.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||||
@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addPattern string
|
||||||
|
removePattern string
|
||||||
|
queryPattern string
|
||||||
|
shouldBeRemoved bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact same pattern",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case difference",
|
||||||
|
addPattern: "Example.Com.",
|
||||||
|
removePattern: "EXAMPLE.COM.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with mixed case, removing with uppercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed case difference",
|
||||||
|
addPattern: "EXAMPLE.ORG.",
|
||||||
|
removePattern: "example.org.",
|
||||||
|
queryPattern: "example.org.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with uppercase, removing with lowercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove wildcard",
|
||||||
|
addPattern: "*.example.com.",
|
||||||
|
removePattern: "*.example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical wildcard patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove transformed pattern",
|
||||||
|
addPattern: "*.example.net.",
|
||||||
|
removePattern: "example.net.",
|
||||||
|
queryPattern: "sub.example.net.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with wildcard, removing with non-wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add transformed pattern, remove wildcard",
|
||||||
|
addPattern: "example.io.",
|
||||||
|
removePattern: "*.example.io.",
|
||||||
|
queryPattern: "example.io.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing dot difference",
|
||||||
|
addPattern: "example.dev",
|
||||||
|
removePattern: "example.dev.",
|
||||||
|
queryPattern: "example.dev.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding without trailing dot, removing with trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed trailing dot difference",
|
||||||
|
addPattern: "example.app.",
|
||||||
|
removePattern: "example.app",
|
||||||
|
queryPattern: "example.app.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with trailing dot, removing without trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case and wildcard",
|
||||||
|
addPattern: "*.Example.Site.",
|
||||||
|
removePattern: "*.EXAMPLE.SITE.",
|
||||||
|
queryPattern: "sub.example.site.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding mixed case wildcard, removing uppercase wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone",
|
||||||
|
addPattern: ".",
|
||||||
|
removePattern: ".",
|
||||||
|
queryPattern: "random.domain.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing root zone",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong domain",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "different.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding one domain, trying to remove a different domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain mismatch",
|
||||||
|
addPattern: "sub.example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding subdomain, trying to remove parent domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parent domain mismatch",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "sub.example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding parent domain, trying to remove subdomain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
handler := &nbdns.MockHandler{}
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// First verify no handler is called before adding any
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS")
|
||||||
|
|
||||||
|
// Add handler
|
||||||
|
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Verify handler is called after adding
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Reset mock for the next test
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
|
||||||
|
// Remove handler
|
||||||
|
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Set up expectations based on whether removal should succeed
|
||||||
|
if !tt.shouldBeRemoved {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if handler is still called after removal attempt
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
if tt.shouldBeRemoved {
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS",
|
||||||
|
"Handler should not be called after successful removal with pattern %q",
|
||||||
|
tt.removePattern)
|
||||||
|
} else {
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
@ -12,8 +14,8 @@ import (
|
|||||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipv4ReverseZone = ".in-addr.arpa"
|
ipv4ReverseZone = ".in-addr.arpa."
|
||||||
ipv6ReverseZone = ".ip6.arpa"
|
ipv6ReverseZone = ".ip6.arpa."
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
|
|
||||||
for _, domain := range nsConfig.Domains {
|
for _, domain := range nsConfig.Domains {
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(domain, "."),
|
Domain: strings.ToLower(dns.Fqdn(domain)),
|
||||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: matchOnly,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.Domain)
|
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
@ -17,9 +17,12 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
userenv = syscall.NewLazyDLL("userenv.dll")
|
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||||
|
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||||
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||||
|
|
||||||
|
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !dConf.MatchOnly {
|
if !dConf.MatchOnly {
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,6 +191,26 @@ func (r *registryConfigurator) string() string {
|
|||||||
return "registry"
|
return "registry"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) flushDNSCache() error {
|
||||||
|
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||||
|
if ret == 0 {
|
||||||
|
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||||
|
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("flushed DNS cache")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
@ -236,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockServer is the mock instance of a dns server
|
// MockServer is the mock instance of a dns server
|
||||||
@ -13,17 +14,17 @@ type MockServer struct {
|
|||||||
InitializeFunc func() error
|
InitializeFunc func() error
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||||
DeregisterHandlerFunc func([]string, int)
|
DeregisterHandlerFunc func(domain.List, int)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
if m.RegisterHandlerFunc != nil {
|
if m.RegisterHandlerFunc != nil {
|
||||||
m.RegisterHandlerFunc(domains, handler, priority)
|
m.RegisterHandlerFunc(domains, handler, priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
|
||||||
if m.DeregisterHandlerFunc != nil {
|
if m.DeregisterHandlerFunc != nil {
|
||||||
m.DeregisterHandlerFunc(domains, priority)
|
m.DeregisterHandlerFunc(domains, priority)
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
searchDomains = append(searchDomains, dConf.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||||
|
@ -6,11 +6,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
@ -18,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
@ -32,8 +35,8 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||||
DeregisterHandler(domains []string, priority int)
|
DeregisterHandler(domains domain.List, priority int)
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@ -65,6 +68,7 @@ type DefaultServer struct {
|
|||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
|
extraDomains map[domain.Domain]int
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
@ -164,13 +168,15 @@ func newDefaultServer(
|
|||||||
stateManager *statemanager.Manager,
|
stateManager *statemanager.Manager,
|
||||||
disableSys bool,
|
disableSys bool,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
|
handlerChain := NewHandlerChain()
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
disableSys: disableSys,
|
disableSys: disableSys,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: handlerChain,
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
@ -181,14 +187,26 @@ func newDefaultServer(
|
|||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// register with root zone, handler chain takes care of the routing
|
||||||
|
dnsService.RegisterMux(".", handlerChain)
|
||||||
|
|
||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||||
|
// Any previously registered handler for the same domain and priority will be replaced.
|
||||||
|
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.registerHandler(domains, handler, priority)
|
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
||||||
|
|
||||||
|
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||||
|
for _, domain := range domains {
|
||||||
|
// convert to zone with simple ref counter
|
||||||
|
s.extraDomains[toZone(domain)]++
|
||||||
|
}
|
||||||
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.handlerChain.AddHandler(domain, handler, priority)
|
s.handlerChain.AddHandler(domain, handler, priority)
|
||||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
// DeregisterHandler deregisters the handler for the given domains with the given priority.
|
||||||
|
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.deregisterHandler(domains, priority)
|
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
||||||
|
for _, domain := range domains {
|
||||||
|
zone := toZone(domain)
|
||||||
|
s.extraDomains[zone]--
|
||||||
|
if s.extraDomains[zone] <= 0 {
|
||||||
|
delete(s.extraDomains, zone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.handlerChain.RemoveHandler(domain, priority)
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
// Only deregister from service if no handlers remain
|
|
||||||
if !s.handlerChain.HasHandlers(domain) {
|
|
||||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
|
|
||||||
|
maps.Clear(s.extraDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
// is the service should be Disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if update.ServiceEnable {
|
if update.ServiceEnable {
|
||||||
_ = s.service.Listen()
|
if err := s.service.Listen(); err != nil {
|
||||||
|
log.Errorf("failed to start DNS service: %v", err)
|
||||||
|
}
|
||||||
} else if !s.permanent {
|
} else if !s.permanent {
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
}
|
}
|
||||||
@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
|
||||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||||
hostUpdate.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
s.applyHostConfig()
|
||||||
log.Error(err)
|
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// persist dns state right away
|
// persist dns state right away
|
||||||
@ -441,6 +462,38 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyHostConfig() {
|
||||||
|
if s.hostManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
config := s.currentConfig
|
||||||
|
|
||||||
|
existingDomains := make(map[string]struct{})
|
||||||
|
for _, d := range config.Domains {
|
||||||
|
existingDomains[d.Domain] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add extra domains only if they're not already in the config
|
||||||
|
for domain := range s.extraDomains {
|
||||||
|
domainStr := domain.PunycodeString()
|
||||||
|
|
||||||
|
if _, exists := existingDomains[domainStr]; !exists {
|
||||||
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
|
Domain: domainStr,
|
||||||
|
MatchOnly: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||||
|
|
||||||
|
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||||
|
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||||
|
s.handleErrNoGroupaAll(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||||
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
||||||
return
|
return
|
||||||
@ -690,10 +743,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
s.applyHostConfig()
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
@ -728,12 +778,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.hostManager != nil {
|
s.applyHostConfig()
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
}
|
}
|
||||||
@ -836,3 +881,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toZone(d domain.Domain) domain.Domain {
|
||||||
|
return domain.Domain(
|
||||||
|
nbdns.NormalizeZone(
|
||||||
|
dns.Fqdn(
|
||||||
|
strings.ToLower(d.PunycodeString()),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
@ -38,7 +39,7 @@ type mocWGIface struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Name() string {
|
func (w *mocWGIface) Name() string {
|
||||||
panic("implement me")
|
return "utun2301"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Address() wgaddr.Address {
|
func (w *mocWGIface) Address() wgaddr.Address {
|
||||||
@ -1448,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtraDomains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialConfig nbdns.Config
|
||||||
|
registerDomains []domain.List
|
||||||
|
deregisterDomains []domain.List
|
||||||
|
finalConfig nbdns.Config
|
||||||
|
expectedDomains []string
|
||||||
|
expectedMatchOnly []string
|
||||||
|
applyHostConfigCall int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Register domains before config update",
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
},
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domains after config update",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register overlapping domains",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "overlap.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "overlap.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"overlap.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register and deregister domains",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
{"extra3.example.com", "extra4.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra3.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
"extra4.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra2.example.com.",
|
||||||
|
"extra4.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domains with ref counter",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "duplicate.example.com"},
|
||||||
|
{"other.example.com", "duplicate.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"duplicate.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
"other.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"other.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Config update with new domains after registration",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "duplicate.example.com"},
|
||||||
|
},
|
||||||
|
finalConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "newconfig.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"newconfig.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deregister domain that is part of customZones",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "protected.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "protected.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"protected.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"config.example.com.",
|
||||||
|
"protected.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domain that is part of nameserver group",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "overlap.ns.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"ns.example.com.",
|
||||||
|
"overlap.ns.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"ns.example.com.",
|
||||||
|
"overlap.ns.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var capturedConfigs []HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfigs = append(capturedConfigs, config)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
wgInterface: &mocWGIface{},
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply initial configuration
|
||||||
|
if tt.initialConfig.ServiceEnable {
|
||||||
|
err := server.applyConfiguration(tt.initialConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register domains
|
||||||
|
for _, domains := range tt.registerDomains {
|
||||||
|
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister domains if specified
|
||||||
|
for _, domains := range tt.deregisterDomains {
|
||||||
|
server.DeregisterHandler(domains, PriorityDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply final configuration if specified
|
||||||
|
if tt.finalConfig.ServiceEnable {
|
||||||
|
err := server.applyConfiguration(tt.finalConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify number of calls
|
||||||
|
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
|
||||||
|
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
|
||||||
|
|
||||||
|
// Get the last applied config
|
||||||
|
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||||
|
|
||||||
|
// Check all expected domains are present
|
||||||
|
domainMap := make(map[string]bool)
|
||||||
|
matchOnlyMap := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, d := range lastConfig.Domains {
|
||||||
|
domainMap[d.Domain] = true
|
||||||
|
if d.MatchOnly {
|
||||||
|
matchOnlyMap[d.Domain] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected domains
|
||||||
|
for _, d := range tt.expectedDomains {
|
||||||
|
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify match-only domains
|
||||||
|
for _, d := range tt.expectedMatchOnly {
|
||||||
|
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no unexpected domains
|
||||||
|
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
|
||||||
|
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtraDomainsRefCounting(t *testing.T) {
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register domains from different handlers with same domain
|
||||||
|
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||||
|
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
// Verify refcount is 2
|
||||||
|
zoneKey := toZone("shared.example.com")
|
||||||
|
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||||
|
|
||||||
|
// Deregister one handler
|
||||||
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
// Verify refcount is 1
|
||||||
|
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||||
|
|
||||||
|
// Deregister the other handler
|
||||||
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
|
||||||
|
|
||||||
|
// Verify domain is removed
|
||||||
|
_, exists := server.extraDomains[zoneKey]
|
||||||
|
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||||
|
var capturedConfig HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfig = config
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
|
|
||||||
|
initialConfig := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := server.applyConfiguration(initialConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
}
|
||||||
|
assert.Contains(t, domains, "config.example.com.")
|
||||||
|
assert.Contains(t, domains, "extra.example.com.")
|
||||||
|
|
||||||
|
// Now apply a new configuration with overlapping domain
|
||||||
|
updatedConfig := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "extra.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = server.applyConfiguration(updatedConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify both domains are in config, but no duplicates
|
||||||
|
domains = []string{}
|
||||||
|
matchOnlyCount := 0
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
if d.MatchOnly {
|
||||||
|
matchOnlyCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, domains, "config.example.com.")
|
||||||
|
assert.Contains(t, domains, "extra.example.com.")
|
||||||
|
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||||
|
|
||||||
|
// Extra domain should no longer be marked as match-only when in config
|
||||||
|
matchOnlyDomain := ""
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||||
|
matchOnlyDomain = d.Domain
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainCaseHandling(t *testing.T) {
|
||||||
|
var capturedConfig HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfig = config
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
|
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||||
|
|
||||||
|
config := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := server.applyConfiguration(config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
}
|
||||||
|
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||||
|
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||||
|
}
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
@ -111,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: dns.Fqdn(dConf.Domain),
|
Domain: dConf.Domain,
|
||||||
MatchOnly: dConf.MatchOnly,
|
MatchOnly: dConf.MatchOnly,
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||||
}
|
}
|
||||||
return s.flushCaches()
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||||
@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.flushCaches()
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
func (s *systemdDbusConfigurator) flushDNSCache() error {
|
||||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||||
|
@ -23,9 +23,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
UpstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
failsTillDeact = int32(5)
|
failsTillDeact = int32(5)
|
||||||
reactivatePeriod = 30 * time.Second
|
reactivatePeriod = 30 * time.Second
|
||||||
upstreamTimeout = 15 * time.Second
|
|
||||||
probeTimeout = 2 * time.Second
|
probeTimeout = 2 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,7 +67,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
domain: domain,
|
domain: domain,
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
|||||||
|
|
||||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||||
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
timeout := upstreamTimeout
|
timeout := UpstreamTimeout
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
timeout = time.Until(deadline)
|
timeout = time.Until(deadline)
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := upstreamTimeout
|
timeout := UpstreamTimeout
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
timeout = time.Until(deadline)
|
timeout = time.Until(deadline)
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
name: "Should Resolve A Record",
|
name: "Should Resolve A Record",
|
||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||||
timeout: upstreamTimeout,
|
timeout: UpstreamTimeout,
|
||||||
expectedAnswer: "1.1.1.1",
|
expectedAnswer: "1.1.1.1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||||
cancelCTX: true,
|
cancelCTX: true,
|
||||||
timeout: upstreamTimeout,
|
timeout: UpstreamTimeout,
|
||||||
responseShouldBeNil: true,
|
responseShouldBeNil: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
r: new(dns.Msg),
|
r: new(dns.Msg),
|
||||||
rtt: time.Millisecond,
|
rtt: time.Millisecond,
|
||||||
},
|
},
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -60,7 +59,7 @@ func (d *DnsInterceptor) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||||
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
d.dnsServer.RegisterHandler(d.route.Domains, d, nbdns.PriorityDNSRoute)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
clear(d.interceptedDomains)
|
clear(d.interceptedDomains)
|
||||||
d.mu.Unlock()
|
d.mu.Unlock()
|
||||||
|
|
||||||
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
|
d.dnsServer.DeregisterHandler(d.route.Domains, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
@ -142,21 +141,24 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
|
// pass if non A/AAAA query
|
||||||
|
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||||
|
d.continueToNextHandler(w, r, "non A/AAAA query")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
peerKey := d.currentPeerKey
|
peerKey := d.currentPeerKey
|
||||||
d.mu.RUnlock()
|
d.mu.RUnlock()
|
||||||
|
|
||||||
if peerKey == "" {
|
if peerKey == "" {
|
||||||
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
|
d.writeDNSError(w, r, "no current peer key")
|
||||||
|
|
||||||
d.continueToNextHandler(w, r, "no current peer key")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get upstream IP: %v", err)
|
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
|
||||||
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,34 +167,43 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
r.SetEdns0(4096, false)
|
r.SetEdns0(4096, false)
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &dns.Client{
|
client := &dns.Client{
|
||||||
Timeout: 5 * time.Second,
|
Timeout: nbdns.UpstreamTimeout,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
}
|
}
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||||
|
|
||||||
var answer []dns.RR
|
|
||||||
if reply != nil {
|
|
||||||
answer = reply.Answer
|
|
||||||
}
|
|
||||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
log.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var answer []dns.RR
|
||||||
|
if reply != nil {
|
||||||
|
answer = reply.Answer
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||||
|
|
||||||
reply.Id = r.Id
|
reply.Id = r.Id
|
||||||
if err := d.writeMsg(w, reply); err != nil {
|
if err := d.writeMsg(w, reply); err != nil {
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
log.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||||
|
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
||||||
|
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeServerFailure)
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS error response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// continueToNextHandler signals the handler chain to try the next handler
|
// continueToNextHandler signals the handler chain to try the next handler
|
||||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||||
|
@ -235,7 +235,7 @@ func (r *Route) resolve(results chan resolveResult) {
|
|||||||
ips, err := r.getIPsFromResolver(domain)
|
ips, err := r.getIPsFromResolver(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
|
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
|
||||||
ips, err = net.LookupIP(string(domain))
|
ips, err = net.LookupIP(domain.PunycodeString())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||||
return
|
return
|
||||||
|
@ -9,5 +9,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||||
return net.LookupIP(string(domain))
|
return net.LookupIP(domain.PunycodeString())
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
msg := new(dns.Msg)
|
||||||
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
|
msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA)
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
|||||||
|
|
||||||
// Convert to proto format
|
// Convert to proto format
|
||||||
for domain, ips := range domainMap {
|
for domain, ips := range domainMap {
|
||||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{
|
||||||
Ips: ips,
|
Ips: ips,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,11 @@ func (d Domain) SafeString() string {
|
|||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PunycodeString returns the punycode representation of the Domain.
|
||||||
|
func (d Domain) PunycodeString() string {
|
||||||
|
return string(d)
|
||||||
|
}
|
||||||
|
|
||||||
// FromString creates a Domain from a string, converting it to punycode.
|
// FromString creates a Domain from a string, converting it to punycode.
|
||||||
func FromString(s string) (Domain, error) {
|
func FromString(s string) (Domain, error) {
|
||||||
ascii, err := idna.ToASCII(s)
|
ascii, err := idna.ToASCII(s)
|
||||||
|
Reference in New Issue
Block a user