[client] Support dns upstream failover for nameserver groups with same match domain (#3178)

This commit is contained in:
Viktor Liu 2025-02-10 18:13:34 +01:00 committed by GitHub
parent 5953b43ead
commit 488b697479
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 747 additions and 186 deletions

View File

@ -12,7 +12,7 @@ import (
const ( const (
PriorityDNSRoute = 100 PriorityDNSRoute = 100
PriorityMatchDomain = 50 PriorityMatchDomain = 50
PriorityDefault = 0 PriorityDefault = 1
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {
@ -26,7 +26,6 @@ type HandlerEntry struct {
Pattern string Pattern string
OrigPattern string OrigPattern string
IsWildcard bool IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool MatchSubdomains bool
} }
@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
} }
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
// First remove any existing handler with same pattern (case-insensitive) and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break break
} }
@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
Pattern: pattern, Pattern: pattern,
OrigPattern: origPattern, OrigPattern: origPattern,
IsWildcard: isWildcard, IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains, MatchSubdomains: matchSubdomains,
} }
@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return return
} }
@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers)) log.Tracef("current handlers (%d):", len(handlers))
for _, h := range handlers { for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
} }
} }
@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
if !matched { if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue continue
} }
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{ chainWriter := &ResponseWriterChain{
ResponseWriter: w, ResponseWriter: w,

View File

@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
dnsRouteHandler := &nbdns.MockHandler{} dnsRouteHandler := &nbdns.MockHandler{}
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
pattern = "*." + tt.handlerDomain[2:] pattern = "*." + tt.handlerDomain[2:]
} }
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain, dns.TypeA)
@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
} }
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
} }
// Create and execute request // Create and execute request
@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3 := &nbdns.MockHandler{} handler3 := &nbdns.MockHandler{}
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
if op.action == "add" { if op.action == "add" {
handler := &nbdns.MockHandler{} handler := &nbdns.MockHandler{}
handlers[op.priority] = handler handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil) chain.AddHandler(op.pattern, handler, op.priority)
} else { } else {
chain.RemoveHandler(op.pattern, op.priority) chain.RemoveHandler(op.pattern, op.priority)
} }
@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r.SetQuestion(testQuery, dns.TypeA) r.SetQuestion(testQuery, dns.TypeA)
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state with all three handlers // Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
handler = mockHandler handler = mockHandler
} }
chain.AddHandler(pattern, handler, h.priority, nil) chain.AddHandler(pattern, handler, h.priority)
} }
// Execute request // Execute request

View File

@ -29,10 +29,15 @@ func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
} }
// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 { if len(r.Question) > 0 {
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
} }
replyMessage := &dns.Msg{} replyMessage := &dns.Msg{}

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"runtime" "runtime"
"strings"
"sync" "sync"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -42,7 +41,12 @@ type Server interface {
ProbeAvailability() ProbeAvailability()
} }
type registeredHandlerMap map[string]handlerWithStop type handlerID string
type nsGroupsByDomain struct {
domain string
groups []*nbdns.NameServerGroup
}
// DefaultServer dns server object // DefaultServer dns server object
type DefaultServer struct { type DefaultServer struct {
@ -52,7 +56,6 @@ type DefaultServer struct {
mux sync.Mutex mux sync.Mutex
service service service service
dnsMuxMap registeredHandlerMap dnsMuxMap registeredHandlerMap
handlerPriorities map[string]int
localResolver *localResolver localResolver *localResolver
wgInterface WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
@ -77,14 +80,17 @@ type handlerWithStop interface {
dns.Handler dns.Handler
stop() stop()
probeAvailability() probeAvailability()
id() handlerID
} }
type muxUpdate struct { type handlerWrapper struct {
domain string domain string
handler handlerWithStop handler handlerWithStop
priority int priority int
} }
type registeredHandlerMap map[handlerID]handlerWrapper
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(
ctx context.Context, ctx context.Context,
@ -158,13 +164,12 @@ func newDefaultServer(
) *DefaultServer { ) *DefaultServer {
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
disableSys: disableSys, disableSys: disableSys,
service: dnsService, service: dnsService,
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
handlerPriorities: make(map[string]int),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
log.Warn("skipping empty domain") log.Warn("skipping empty domain")
continue continue
} }
s.handlerChain.AddHandler(domain, handler, priority, nil) s.handlerChain.AddHandler(domain, handler, priority)
s.handlerPriorities[domain] = priority
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
} }
} }
@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority) log.Debugf("deregistering handler %v with priority %d", domains, priority)
for _, domain := range domains { for _, domain := range domains {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.handlerChain.RemoveHandler(domain, priority) s.handlerChain.RemoveHandler(domain, priority)
// Only deregister from service if no handlers remain // Only deregister from service if no handlers remain
if !s.handlerChain.HasHandlers(domain) { if !s.handlerChain.HasHandlers(domain) {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.service.DeregisterMux(nbdns.NormalizeZone(domain)) s.service.DeregisterMux(nbdns.NormalizeZone(domain))
} }
} }
@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
_, ok := s.dnsMuxMap[nbdns.RootZone] // Check if there's any root handler
if ok { var hasRootHandler bool
for _, handler := range s.dnsMuxMap {
if handler.domain == nbdns.RootZone {
hasRootHandler = true
break
}
}
if hasRootHandler {
log.Debugf("on new host DNS config but skip to apply it") log.Debugf("on new host DNS config but skip to apply it")
return return
} }
log.Debugf("update host DNS settings: %+v", hostsDnsList) log.Debugf("update host DNS settings: %+v", hostsDnsList)
s.addHostRootZone() s.addHostRootZone()
} }
@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() {
go func(mux handlerWithStop) { go func(mux handlerWithStop) {
defer wg.Done() defer wg.Done()
mux.probeAvailability() mux.probeAvailability()
}(mux) }(mux.handler)
} }
wg.Wait() wg.Wait()
} }
@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
return nil return nil
} }
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate var muxUpdates []handlerWrapper
localRecords := make(map[string]nbdns.SimpleRecord, 0) localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones { for _, customZone := range customZones {
@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
return nil, nil, fmt.Errorf("received an empty list of records") return nil, nil, fmt.Errorf("received an empty list of records")
} }
muxUpdates = append(muxUpdates, muxUpdate{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityMatchDomain,
@ -446,15 +461,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
return muxUpdates, localRecords, nil return muxUpdates, localRecords, nil
} }
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
var muxUpdates []handlerWrapper
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups { for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 { if len(nsGroup.NameServers) == 0 {
log.Warn("received a nameserver group with empty nameserver list") log.Warn("received a nameserver group with empty nameserver list")
continue continue
} }
if !nsGroup.Primary && len(nsGroup.Domains) == 0 {
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
}
}
groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain
if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault
}
updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority)
if err != nil {
return nil, err
}
muxUpdates = append(muxUpdates, updates...)
}
return muxUpdates, nil
}
func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) {
var muxUpdates []handlerWrapper
for i, nsGroup := range domainGroup.groups {
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i
// Check if we're about to overlap with the next priority tier
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
break
}
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver( handler, err := newUpstreamResolver(
s.ctx, s.ctx,
s.wgInterface.Name(), s.wgInterface.Name(),
@ -462,10 +521,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder, s.hostsDNSHolder,
domainGroup.domain,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) return nil, fmt.Errorf("create upstream resolver: %v", err)
} }
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType { if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
@ -489,78 +550,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
// after some period defined by upstream it tries to reactivate self by calling this hook // after some period defined by upstream it tries to reactivate self by calling this hook
// everything we need here is just to re-apply current configuration because it already // everything we need here is just to re-apply current configuration because it already
// contains this upstream settings (temporal deactivation not removed it) // contains this upstream settings (temporal deactivation not removed it)
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority)
if nsGroup.Primary { muxUpdates = append(muxUpdates, handlerWrapper{
muxUpdates = append(muxUpdates, muxUpdate{ domain: domainGroup.domain,
domain: nbdns.RootZone, handler: handler,
handler: handler, priority: priority,
priority: PriorityDefault, })
})
continue
}
if len(nsGroup.Domains) == 0 {
handler.stop()
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
}
for _, domain := range nsGroup.Domains {
if domain == "" {
handler.stop()
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
priority: PriorityMatchDomain,
})
}
} }
return muxUpdates, nil return muxUpdates, nil
} }
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
muxUpdateMap := make(registeredHandlerMap) // this will introduce a short period of time when the server is not able to handle DNS requests
handlersByPriority := make(map[string]int) for _, existing := range s.dnsMuxMap {
s.deregisterHandler([]string{existing.domain}, existing.priority)
var isContainRootUpdate bool existing.handler.stop()
// First register new handlers
for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.domain] = update.handler
handlersByPriority[update.domain] = update.priority
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop()
}
if update.domain == nbdns.RootZone {
isContainRootUpdate = true
}
} }
// Then deregister old handlers not in the update muxUpdateMap := make(registeredHandlerMap)
for key, existingHandler := range s.dnsMuxMap { var containsRootUpdate bool
_, found := muxUpdateMap[key]
if !found { for _, update := range muxUpdates {
if !isContainRootUpdate && key == nbdns.RootZone { if update.domain == nbdns.RootZone {
containsRootUpdate = true
}
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.handler.id()] = update
}
// If there's no root update and we had a root handler, restore it
if !containsRootUpdate {
for _, existing := range s.dnsMuxMap {
if existing.domain == nbdns.RootZone {
s.addHostRootZone() s.addHostRootZone()
existingHandler.stop() break
} else {
existingHandler.stop()
// Deregister with the priority that was used to register
if oldPriority, ok := s.handlerPriorities[key]; ok {
s.deregisterHandler([]string{key}, oldPriority)
}
} }
} }
} }
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
s.handlerPriorities = handlersByPriority
} }
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
@ -593,6 +623,7 @@ func getNSHostPort(ns nbdns.NameServer) string {
func (s *DefaultServer) upstreamCallbacks( func (s *DefaultServer) upstreamCallbacks(
nsGroup *nbdns.NameServerGroup, nsGroup *nbdns.NameServerGroup,
handler dns.Handler, handler dns.Handler,
priority int,
) (deactivate func(error), reactivate func()) { ) (deactivate func(error), reactivate func()) {
var removeIndex map[string]int var removeIndex map[string]int
deactivate = func(err error) { deactivate = func(err error) {
@ -609,13 +640,13 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) s.deregisterHandler([]string{nbdns.RootZone}, priority)
} }
for i, item := range s.currentConfig.Domains { for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found { if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true s.currentConfig.Domains[i].Disabled = true
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) s.deregisterHandler([]string{item.Domain}, priority)
removeIndex[item.Domain] = i removeIndex[item.Domain] = i
} }
} }
@ -635,8 +666,8 @@ func (s *DefaultServer) upstreamCallbacks(
} }
s.updateNSState(nsGroup, err, false) s.updateNSState(nsGroup, err, false)
} }
reactivate = func() { reactivate = func() {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@ -646,7 +677,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue continue
} }
s.currentConfig.Domains[i].Disabled = false s.currentConfig.Domains[i].Disabled = false
s.registerHandler([]string{domain}, handler, PriorityMatchDomain) s.registerHandler([]string{domain}, handler, priority)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@ -654,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) s.registerHandler([]string{nbdns.RootZone}, handler, priority)
} }
if s.hostManager != nil { if s.hostManager != nil {
@ -676,6 +707,7 @@ func (s *DefaultServer) addHostRootZone() {
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder, s.hostsDNSHolder,
nbdns.RootZone,
) )
if err != nil { if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err) log.Errorf("unable to create a new upstream resolver, error: %v", err)
@ -732,5 +764,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
} }
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ",")) return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
}
// groupNSGroupsByDomain groups nameserver groups by their match domains
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
domainMap := make(map[string][]*nbdns.NameServerGroup)
for _, group := range nsGroups {
if group.Primary {
domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group)
continue
}
for _, domain := range group.Domains {
if domain == "" {
continue
}
domainMap[domain] = append(domainMap[domain], group)
}
}
var result []nsGroupsByDomain
for domain, groups := range domainMap {
result = append(result, nsGroupsByDomain{
domain: domain,
groups: groups,
})
}
return result
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -88,6 +89,18 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger()) formatter.SetTextFormatter(log.StandardLogger())
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string
for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv))
}
return &upstreamResolverBase{
domain: domain,
upstreamServers: srvs,
cancel: func() {},
}
}
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
{ {
@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler}, expectedUpstreamMap: registeredHandlerMap{
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
dummyHandler.id(): handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
generateDummyHandler(".", nameServers).id(): handlerWrapper{
domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault,
},
},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
}, },
{ {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler}, initUpstreamMap: registeredHandlerMap{
initSerial: 0, generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
inputSerial: 1, domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler}, expectedUpstreamMap: registeredHandlerMap{
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
domain: "netbird.io",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
"local-resolver": handlerWrapper{
domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
}, },
{ {
name: "Smaller Config Serial Should Be Skipped", name: "Smaller Config Serial Should Be Skipped",
@ -242,9 +288,15 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true}, inputUpdate: nbdns.Config{ServiceEnable: true},
@ -252,9 +304,15 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalMap: make(registrationMap), expectedLocalMap: make(registrationMap),
}, },
{ {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
domain: zoneRecords[0].Name,
handler: dummyHandler,
priority: PriorityMatchDomain,
},
},
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false}, inputUpdate: nbdns.Config{ServiceEnable: false},
@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
}() }()
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}} dnsServer.dnsMuxMap = registeredHandlerMap{
"id1": handlerWrapper{
domain: zoneRecords[0].Name,
handler: &localResolver{},
priority: PriorityMatchDomain,
},
}
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
dnsServer.updateSerial = 0 dnsServer.updateSerial = 0
@ -562,9 +626,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
handlerPriorities: make(map[string]int), hostManager: hostManager,
hostManager: hostManager,
currentConfig: HostDNSConfig{ currentConfig: HostDNSConfig{
Domains: []DomainConfig{ Domains: []DomainConfig{
{false, "domain0", false}, {false, "domain0", false},
@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
}, },
}, nil) }, nil, 0)
deactivate(nil) deactivate(nil)
expected := "domain0,domain2" expected := "domain0,domain2"
@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
Subdomains: true, Subdomains: true,
} }
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
testCases := []struct { testCases := []struct {
name string name string
@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
}) })
} }
} }
type mockHandler struct {
Id string
}
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) stop() {}
func (m *mockHandler) probeAvailability() {}
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
type mockService struct{}
func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() {}
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := registeredHandlerMap{
"upstream-group1": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
}
baseRootHandlers := registeredHandlerMap{
"upstream-root1": {
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
"upstream-root2": {
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
}
baseMixedHandlers := registeredHandlerMap{
"upstream-group1": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
"upstream-group2": {
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
"upstream-other": {
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
},
}
tests := []struct {
name string
initialHandlers registeredHandlerMap
updates []handlerWrapper
expectedHandlers map[string]string // map[handlerID]domain
description string
}{
{
name: "Remove group1 from update",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Only group2 remains
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
"upstream-group2": "example.com",
},
description: "When group1 is not included in the update, it should be removed while group2 remains",
},
{
name: "Remove group2 from update",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Only group1 remains
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
},
description: "When group2 is not included in the update, it should be removed while group1 remains",
},
{
name: "Add group3 in first position",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Add group3 with highest priority
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain + 1,
},
// Keep existing groups with their original priorities
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-group3": "example.com",
},
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
},
{
name: "Add group3 in last position",
initialHandlers: baseMatchHandlers,
updates: []handlerWrapper{
// Keep existing groups with their original priorities
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
// Add group3 with lowest priority
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group3",
},
priority: PriorityMatchDomain - 2,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-group3": "example.com",
},
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
},
// Root zone tests
{
name: "Remove root1 from update",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
},
expectedHandlers: map[string]string{
"upstream-root2": ".",
},
description: "When root1 is not included in the update, it should be removed while root2 remains",
},
{
name: "Remove root2 from update",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
},
expectedHandlers: map[string]string{
"upstream-root1": ".",
},
description: "When root2 is not included in the update, it should be removed while root1 remains",
},
{
name: "Add root3 in first position",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root3",
},
priority: PriorityDefault + 1,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
},
expectedHandlers: map[string]string{
"upstream-root1": ".",
"upstream-root2": ".",
"upstream-root3": ".",
},
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
},
{
name: "Add root3 in last position",
initialHandlers: baseRootHandlers,
updates: []handlerWrapper{
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root1",
},
priority: PriorityDefault,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root2",
},
priority: PriorityDefault - 1,
},
{
domain: ".",
handler: &mockHandler{
Id: "upstream-root3",
},
priority: PriorityDefault - 2,
},
},
expectedHandlers: map[string]string{
"upstream-root1": ".",
"upstream-root2": ".",
"upstream-root3": ".",
},
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
},
// Mixed domain tests
{
name: "Update with mixed domains - remove one of duplicate domain",
initialHandlers: baseMixedHandlers,
updates: []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-other": "other.com",
},
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
},
{
name: "Update with mixed domains - add new domain",
initialHandlers: baseMixedHandlers,
updates: []handlerWrapper{
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group1",
},
priority: PriorityMatchDomain,
},
{
domain: "example.com",
handler: &mockHandler{
Id: "upstream-group2",
},
priority: PriorityMatchDomain - 1,
},
{
domain: "other.com",
handler: &mockHandler{
Id: "upstream-other",
},
priority: PriorityMatchDomain,
},
{
domain: "new.com",
handler: &mockHandler{
Id: "upstream-new",
},
priority: PriorityMatchDomain,
},
},
expectedHandlers: map[string]string{
"upstream-group1": "example.com",
"upstream-group2": "example.com",
"upstream-other": "other.com",
"upstream-new": "new.com",
},
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{
dnsMuxMap: tt.initialHandlers,
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Perform the update
server.updateMux(tt.updates)
// Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
"Number of handlers after update doesn't match expected")
// Check each expected handler
for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[handlerID(id)]
assert.True(t, exists, "Expected handler %s not found", id)
if exists {
assert.Equal(t, expectedDomain, handler.domain,
"Domain mismatch for handler %s", id)
}
}
// Verify no unexpected handlers exist
for handlerID := range server.dnsMuxMap {
_, expected := tt.expectedHandlers[string(handlerID)]
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
}
// Verify the handlerChain state and order
previousPriority := 0
for _, chainEntry := range server.handlerChain.handlers {
// Verify priority order
if previousPriority > 0 {
assert.True(t, chainEntry.Priority <= previousPriority,
"Handlers in chain not properly ordered by priority")
}
previousPriority = chainEntry.Priority
// Verify handler exists in mux
foundInMux := false
for _, muxEntry := range server.dnsMuxMap {
if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
foundInMux = true
break
}
}
assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxMap")
}
})
}
}

View File

@ -2,9 +2,13 @@ package dns
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"slices"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -40,6 +44,7 @@ type upstreamResolverBase struct {
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []string
domain string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
@ -53,12 +58,13 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase { func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{ return &upstreamResolverBase{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
domain: domain,
upstreamTimeout: upstreamTimeout, upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers) return fmt.Sprintf("upstream %v", u.upstreamServers)
} }
// ID returns the unique handler ID
func (u *upstreamResolverBase) id() handlerID {
servers := slices.Clone(u.upstreamServers)
slices.Sort(servers)
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ",")))
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
func (u *upstreamResolverBase) MatchSubdomains() bool { func (u *upstreamResolverBase) MatchSubdomains() bool {
return true return true
} }
@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
u.checkUpstreamFails(err) u.checkUpstreamFails(err)
}() }()
log.WithField("question", r.Question[0]).Trace("received an upstream question") log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records // set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil { if r.Extra == nil {
r.SetEdns0(4096, false) r.SetEdns0(4096, false)
@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
log.Tracef("%s has been stopped", u)
return return
default: default:
} }
@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.WithError(err).WithField("upstream", upstream). log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
Warn("got an error while connecting to upstream")
continue continue
} }
u.failsCount.Add(1) log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
log.WithError(err).WithField("upstream", upstream). continue
Error("got other error while querying the upstream")
return
} }
if rm == nil { if rm == nil || !rm.Response {
log.WithError(err).WithField("upstream", upstream). log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
Warn("no response from upstream") continue
return
}
// those checks need to be independent of each other due to memory address issues
if !rm.Response {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
return
} }
u.successCount.Add(1) u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream) log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
err = w.WriteMsg(rm) if err = w.WriteMsg(rm); err != nil {
if err != nil { log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
log.WithError(err).Error("got an error while writing the upstream resolver response")
} }
// count the fails only if they happen sequentially // count the fails only if they happen sequentially
u.failsCount.Store(0) u.failsCount.Store(0)
return return
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.Error("all queries to the upstream nameservers failed with timeout") log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
}
} }
// checkUpstreamFails counts fails and disables or enables upstream resolving // checkUpstreamFails counts fails and disables or enables upstream resolving

View File

@ -27,8 +27,9 @@ func newUpstreamResolver(
_ *net.IPNet, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
c := &upstreamResolver{ c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder, hostsDNSHolder: hostsDNSHolder,

View File

@ -23,8 +23,9 @@ func newUpstreamResolver(
_ *net.IPNet, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string,
) (*upstreamResolver, error) { ) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{ nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
} }

View File

@ -30,8 +30,9 @@ func newUpstreamResolver(
net *net.IPNet, net *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string,
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
ios := &upstreamResolverIOS{ ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,

View File

@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
timeout time.Duration timeout time.Duration
cancelCTX bool cancelCTX bool
expectedAnswer string expectedAnswer string
acceptNXDomain bool
}{ }{
{ {
name: "Should Resolve A Record", name: "Should Resolve A Record",
@ -36,11 +37,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
expectedAnswer: "1.1.1.1", expectedAnswer: "1.1.1.1",
}, },
{ {
name: "Should Not Resolve If Can't Connect To Both Servers", name: "Should Not Resolve If Can't Connect To Both Servers",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"}, InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
timeout: 200 * time.Millisecond, timeout: 200 * time.Millisecond,
responseShouldBeNil: true, acceptNXDomain: true,
}, },
{ {
name: "Should Not Resolve If Parent Context Is Canceled", name: "Should Not Resolve If Parent Context Is Canceled",
@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
responseShouldBeNil: true, responseShouldBeNil: true,
}, },
} }
// should resolve if first upstream times out
// should not write when both fails
// should not resolve if parent context is canceled
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil) resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {
@ -84,16 +82,22 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
t.Fatalf("should write a response message") t.Fatalf("should write a response message")
} }
foundAnswer := false if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError {
for _, answer := range responseMSG.Answer { return
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
} }
if !foundAnswer { if testCase.expectedAnswer != "" {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer) foundAnswer := false
for _, answer := range responseMSG.Answer {
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
}
if !foundAnswer {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
}
} }
}) })
} }

View File

@ -721,7 +721,9 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
func (d *Status) GetDNSStates() []NSGroupState { func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
return d.nsGroupStates
// shallow copy is good enough, as slices fields are currently not updated
return slices.Clone(d.nsGroupStates)
} }
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {