mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-11 20:36:40 +02:00
[client] Support dns upstream failover for nameserver groups with same match domain (#3178)
This commit is contained in:
parent
5953b43ead
commit
488b697479
@ -12,7 +12,7 @@ import (
|
||||
const (
|
||||
PriorityDNSRoute = 100
|
||||
PriorityMatchDomain = 50
|
||||
PriorityDefault = 0
|
||||
PriorityDefault = 1
|
||||
)
|
||||
|
||||
type SubdomainMatcher interface {
|
||||
@ -26,7 +26,6 @@ type HandlerEntry struct {
|
||||
Pattern string
|
||||
OrigPattern string
|
||||
IsWildcard bool
|
||||
StopHandler handlerWithStop
|
||||
MatchSubdomains bool
|
||||
}
|
||||
|
||||
@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
|
||||
}
|
||||
|
||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
if c.handlers[i].StopHandler != nil {
|
||||
c.handlers[i].StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
Pattern: pattern,
|
||||
OrigPattern: origPattern,
|
||||
IsWildcard: isWildcard,
|
||||
StopHandler: stopHandler,
|
||||
MatchSubdomains: matchSubdomains,
|
||||
}
|
||||
|
||||
@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
}
|
||||
@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("current handlers (%d):", len(handlers))
|
||||
for _, h := range handlers {
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||
}
|
||||
}
|
||||
|
||||
@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
if !matched {
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
|
@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
dnsRouteHandler := &nbdns.MockHandler{}
|
||||
|
||||
// Setup handlers with different priorities
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
pattern = "*." + tt.handlerDomain[2:]
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||
}
|
||||
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
|
||||
}
|
||||
|
||||
// Create and execute request
|
||||
@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
handler3 := &nbdns.MockHandler{}
|
||||
|
||||
// Add handlers in priority order
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
if op.action == "add" {
|
||||
handler := &nbdns.MockHandler{}
|
||||
handlers[op.priority] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
chain.AddHandler(op.pattern, handler, op.priority)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
r.SetQuestion(testQuery, dns.TypeA)
|
||||
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state with all three handlers
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
handler = mockHandler
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||
chain.AddHandler(pattern, handler, h.priority)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
|
@ -29,10 +29,15 @@ func (d *localResolver) String() string {
|
||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (d *localResolver) id() handlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) > 0 {
|
||||
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@ -42,7 +41,12 @@ type Server interface {
|
||||
ProbeAvailability()
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
type handlerID string
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
domain string
|
||||
groups []*nbdns.NameServerGroup
|
||||
}
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
@ -52,7 +56,6 @@ type DefaultServer struct {
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
handlerPriorities map[string]int
|
||||
localResolver *localResolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
@ -77,14 +80,17 @@ type handlerWithStop interface {
|
||||
dns.Handler
|
||||
stop()
|
||||
probeAvailability()
|
||||
id() handlerID
|
||||
}
|
||||
|
||||
type muxUpdate struct {
|
||||
type handlerWrapper struct {
|
||||
domain string
|
||||
handler handlerWithStop
|
||||
priority int
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[handlerID]handlerWrapper
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
@ -158,13 +164,12 @@ func newDefaultServer(
|
||||
) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
handlerPriorities: make(map[string]int),
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
||||
log.Warn("skipping empty domain")
|
||||
continue
|
||||
}
|
||||
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||
s.handlerPriorities[domain] = priority
|
||||
s.handlerChain.AddHandler(domain, handler, priority)
|
||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||
}
|
||||
}
|
||||
@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||
|
||||
for _, domain := range domains {
|
||||
if domain == "" {
|
||||
log.Warn("skipping empty domain")
|
||||
continue
|
||||
}
|
||||
|
||||
s.handlerChain.RemoveHandler(domain, priority)
|
||||
|
||||
// Only deregister from service if no handlers remain
|
||||
if !s.handlerChain.HasHandlers(domain) {
|
||||
if domain == "" {
|
||||
log.Warn("skipping empty domain")
|
||||
continue
|
||||
}
|
||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||
}
|
||||
}
|
||||
@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() {
|
||||
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
s.hostsDNSHolder.set(hostsDnsList)
|
||||
|
||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||
if ok {
|
||||
// Check if there's any root handler
|
||||
var hasRootHandler bool
|
||||
for _, handler := range s.dnsMuxMap {
|
||||
if handler.domain == nbdns.RootZone {
|
||||
hasRootHandler = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasRootHandler {
|
||||
log.Debugf("on new host DNS config but skip to apply it")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||
s.addHostRootZone()
|
||||
}
|
||||
@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
mux.probeAvailability()
|
||||
}(mux)
|
||||
}(mux.handler)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
priority: PriorityMatchDomain,
|
||||
@ -446,15 +461,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
log.Warn("received a nameserver group with empty nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
if !nsGroup.Primary && len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||
|
||||
for _, domainGroup := range groupedNS {
|
||||
basePriority := PriorityMatchDomain
|
||||
if domainGroup.domain == nbdns.RootZone {
|
||||
basePriority = PriorityDefault
|
||||
}
|
||||
|
||||
updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
muxUpdates = append(muxUpdates, updates...)
|
||||
}
|
||||
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) {
|
||||
var muxUpdates []handlerWrapper
|
||||
|
||||
for i, nsGroup := range domainGroup.groups {
|
||||
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||
priority := basePriority - i
|
||||
|
||||
// Check if we're about to overlap with the next priority tier
|
||||
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
||||
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
s.wgInterface.Name(),
|
||||
@ -462,10 +521,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
domainGroup.domain,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||
}
|
||||
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||
@ -489,78 +550,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority)
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
priority: PriorityDefault,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
handler.stop()
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
handler.stop()
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
}
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
domain: domainGroup.domain,
|
||||
handler: handler,
|
||||
priority: priority,
|
||||
})
|
||||
}
|
||||
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
handlersByPriority := make(map[string]int)
|
||||
|
||||
var isContainRootUpdate bool
|
||||
|
||||
// First register new handlers
|
||||
for _, update := range muxUpdates {
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.domain] = update.handler
|
||||
handlersByPriority[update.domain] = update.priority
|
||||
|
||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||
existingHandler.stop()
|
||||
}
|
||||
|
||||
if update.domain == nbdns.RootZone {
|
||||
isContainRootUpdate = true
|
||||
}
|
||||
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||
existing.handler.stop()
|
||||
}
|
||||
|
||||
// Then deregister old handlers not in the update
|
||||
for key, existingHandler := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
var containsRootUpdate bool
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
if update.domain == nbdns.RootZone {
|
||||
containsRootUpdate = true
|
||||
}
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.handler.id()] = update
|
||||
}
|
||||
|
||||
// If there's no root update and we had a root handler, restore it
|
||||
if !containsRootUpdate {
|
||||
for _, existing := range s.dnsMuxMap {
|
||||
if existing.domain == nbdns.RootZone {
|
||||
s.addHostRootZone()
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
// Deregister with the priority that was used to register
|
||||
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
||||
s.deregisterHandler([]string{key}, oldPriority)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
s.handlerPriorities = handlersByPriority
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
@ -593,6 +623,7 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
priority int,
|
||||
) (deactivate func(error), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func(err error) {
|
||||
@ -609,13 +640,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.RouteAll = false
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, priority)
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
if _, found := removeIndex[item.Domain]; found {
|
||||
s.currentConfig.Domains[i].Disabled = true
|
||||
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
||||
s.deregisterHandler([]string{item.Domain}, priority)
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
@ -635,8 +666,8 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, err, false)
|
||||
|
||||
}
|
||||
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@ -646,7 +677,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
continue
|
||||
}
|
||||
s.currentConfig.Domains[i].Disabled = false
|
||||
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||
s.registerHandler([]string{domain}, handler, priority)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@ -654,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.RouteAll = true
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
if s.hostManager != nil {
|
||||
@ -676,6 +707,7 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
s.wgInterface.Address().Network,
|
||||
s.statusRecorder,
|
||||
s.hostsDNSHolder,
|
||||
nbdns.RootZone,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||
@ -732,5 +764,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
||||
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||
}
|
||||
|
||||
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
||||
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
||||
domainMap := make(map[string][]*nbdns.NameServerGroup)
|
||||
|
||||
for _, group := range nsGroups {
|
||||
if group.Primary {
|
||||
domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, domain := range group.Domains {
|
||||
if domain == "" {
|
||||
continue
|
||||
}
|
||||
domainMap[domain] = append(domainMap[domain], group)
|
||||
}
|
||||
}
|
||||
|
||||
var result []nsGroupsByDomain
|
||||
for domain, groups := range domainMap {
|
||||
result = append(result, nsGroupsByDomain{
|
||||
domain: domain,
|
||||
groups: groups,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
@ -88,6 +89,18 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []string
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, getNSHostPort(srv))
|
||||
}
|
||||
return &upstreamResolverBase{
|
||||
domain: domain,
|
||||
upstreamServers: srvs,
|
||||
cancel: func() {},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
dummyHandler.id(): handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
domain: nbdns.RootZone,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "New Config Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
expectedUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||
domain: "netbird.io",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"local-resolver": handlerWrapper{
|
||||
domain: "netbird.cloud",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
@ -242,9 +288,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||
@ -252,9 +304,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalMap: make(registrationMap),
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: dummyHandler,
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||
@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||
"id1": handlerWrapper{
|
||||
domain: zoneRecords[0].Name,
|
||||
handler: &localResolver{},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||
dnsServer.updateSerial = 0
|
||||
|
||||
@ -562,9 +626,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
handlerChain: NewHandlerChain(),
|
||||
handlerPriorities: make(map[string]int),
|
||||
hostManager: hostManager,
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil)
|
||||
}, nil, 0)
|
||||
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
Subdomains: true,
|
||||
}
|
||||
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockHandler struct {
|
||||
Id string
|
||||
}
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) stop() {}
|
||||
func (m *mockHandler) probeAvailability() {}
|
||||
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
func (m *mockService) DeregisterMux(string) {}
|
||||
|
||||
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
baseMatchHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
}
|
||||
|
||||
baseRootHandlers := registeredHandlerMap{
|
||||
"upstream-root1": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
"upstream-root2": {
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
}
|
||||
|
||||
baseMixedHandlers := registeredHandlerMap{
|
||||
"upstream-group1": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
"upstream-group2": {
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
"upstream-other": {
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialHandlers registeredHandlerMap
|
||||
updates []handlerWrapper
|
||||
expectedHandlers map[string]string // map[handlerID]domain
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Remove group1 from update",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Only group2 remains
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group2": "example.com",
|
||||
},
|
||||
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
||||
},
|
||||
{
|
||||
name: "Remove group2 from update",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Only group1 remains
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
},
|
||||
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
||||
},
|
||||
{
|
||||
name: "Add group3 in first position",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Add group3 with highest priority
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group3",
|
||||
},
|
||||
priority: PriorityMatchDomain + 1,
|
||||
},
|
||||
// Keep existing groups with their original priorities
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-group3": "example.com",
|
||||
},
|
||||
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
|
||||
},
|
||||
{
|
||||
name: "Add group3 in last position",
|
||||
initialHandlers: baseMatchHandlers,
|
||||
updates: []handlerWrapper{
|
||||
// Keep existing groups with their original priorities
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
// Add group3 with lowest priority
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group3",
|
||||
},
|
||||
priority: PriorityMatchDomain - 2,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-group3": "example.com",
|
||||
},
|
||||
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
|
||||
},
|
||||
// Root zone tests
|
||||
{
|
||||
name: "Remove root1 from update",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root2": ".",
|
||||
},
|
||||
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
||||
},
|
||||
{
|
||||
name: "Remove root2 from update",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root1": ".",
|
||||
},
|
||||
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
||||
},
|
||||
{
|
||||
name: "Add root3 in first position",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root3",
|
||||
},
|
||||
priority: PriorityDefault + 1,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root1": ".",
|
||||
"upstream-root2": ".",
|
||||
"upstream-root3": ".",
|
||||
},
|
||||
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
|
||||
},
|
||||
{
|
||||
name: "Add root3 in last position",
|
||||
initialHandlers: baseRootHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root1",
|
||||
},
|
||||
priority: PriorityDefault,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root2",
|
||||
},
|
||||
priority: PriorityDefault - 1,
|
||||
},
|
||||
{
|
||||
domain: ".",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-root3",
|
||||
},
|
||||
priority: PriorityDefault - 2,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-root1": ".",
|
||||
"upstream-root2": ".",
|
||||
"upstream-root3": ".",
|
||||
},
|
||||
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
|
||||
},
|
||||
// Mixed domain tests
|
||||
{
|
||||
name: "Update with mixed domains - remove one of duplicate domain",
|
||||
initialHandlers: baseMixedHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-other": "other.com",
|
||||
},
|
||||
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
|
||||
},
|
||||
{
|
||||
name: "Update with mixed domains - add new domain",
|
||||
initialHandlers: baseMixedHandlers,
|
||||
updates: []handlerWrapper{
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group1",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "example.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-group2",
|
||||
},
|
||||
priority: PriorityMatchDomain - 1,
|
||||
},
|
||||
{
|
||||
domain: "other.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-other",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
{
|
||||
domain: "new.com",
|
||||
handler: &mockHandler{
|
||||
Id: "upstream-new",
|
||||
},
|
||||
priority: PriorityMatchDomain,
|
||||
},
|
||||
},
|
||||
expectedHandlers: map[string]string{
|
||||
"upstream-group1": "example.com",
|
||||
"upstream-group2": "example.com",
|
||||
"upstream-other": "other.com",
|
||||
"upstream-new": "new.com",
|
||||
},
|
||||
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &DefaultServer{
|
||||
dnsMuxMap: tt.initialHandlers,
|
||||
handlerChain: NewHandlerChain(),
|
||||
service: &mockService{},
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
server.updateMux(tt.updates)
|
||||
|
||||
// Verify the results
|
||||
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||
"Number of handlers after update doesn't match expected")
|
||||
|
||||
// Check each expected handler
|
||||
for id, expectedDomain := range tt.expectedHandlers {
|
||||
handler, exists := server.dnsMuxMap[handlerID(id)]
|
||||
assert.True(t, exists, "Expected handler %s not found", id)
|
||||
if exists {
|
||||
assert.Equal(t, expectedDomain, handler.domain,
|
||||
"Domain mismatch for handler %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected handlers exist
|
||||
for handlerID := range server.dnsMuxMap {
|
||||
_, expected := tt.expectedHandlers[string(handlerID)]
|
||||
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
||||
}
|
||||
|
||||
// Verify the handlerChain state and order
|
||||
previousPriority := 0
|
||||
for _, chainEntry := range server.handlerChain.handlers {
|
||||
// Verify priority order
|
||||
if previousPriority > 0 {
|
||||
assert.True(t, chainEntry.Priority <= previousPriority,
|
||||
"Handlers in chain not properly ordered by priority")
|
||||
}
|
||||
previousPriority = chainEntry.Priority
|
||||
|
||||
// Verify handler exists in mux
|
||||
foundInMux := false
|
||||
for _, muxEntry := range server.dnsMuxMap {
|
||||
if chainEntry.Handler == muxEntry.handler &&
|
||||
chainEntry.Priority == muxEntry.priority &&
|
||||
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||
foundInMux = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInMux,
|
||||
"Handler in chain not found in dnsMuxMap")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,9 +2,13 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -40,6 +44,7 @@ type upstreamResolverBase struct {
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []string
|
||||
domain string
|
||||
disabled bool
|
||||
failsCount atomic.Int32
|
||||
successCount atomic.Int32
|
||||
@ -53,12 +58,13 @@ type upstreamResolverBase struct {
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: domain,
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
func (u *upstreamResolverBase) id() handlerID {
|
||||
servers := slices.Clone(u.upstreamServers)
|
||||
slices.Sort(servers)
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(u.domain + ":"))
|
||||
hash.Write([]byte(strings.Join(servers, ",")))
|
||||
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
u.checkUpstreamFails(err)
|
||||
}()
|
||||
|
||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
log.Tracef("%s has been stopped", u)
|
||||
return
|
||||
default:
|
||||
}
|
||||
@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Warn("got an error while connecting to upstream")
|
||||
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||
continue
|
||||
}
|
||||
u.failsCount.Add(1)
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Error("got other error while querying the upstream")
|
||||
return
|
||||
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if rm == nil {
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Warn("no response from upstream")
|
||||
return
|
||||
}
|
||||
// those checks need to be independent of each other due to memory address issues
|
||||
if !rm.Response {
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Warn("no response from upstream")
|
||||
return
|
||||
if rm == nil || !rm.Response {
|
||||
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||
continue
|
||||
}
|
||||
|
||||
u.successCount.Add(1)
|
||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
||||
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||
|
||||
err = w.WriteMsg(rm)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
||||
if err = w.WriteMsg(rm); err != nil {
|
||||
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||
}
|
||||
// count the fails only if they happen sequentially
|
||||
u.failsCount.Store(0)
|
||||
return
|
||||
}
|
||||
u.failsCount.Add(1)
|
||||
log.Error("all queries to the upstream nameservers failed with timeout")
|
||||
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(r, dns.RcodeServerFailure)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||
|
@ -27,8 +27,9 @@ func newUpstreamResolver(
|
||||
_ *net.IPNet,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
c := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
hostsDNSHolder: hostsDNSHolder,
|
||||
|
@ -23,8 +23,9 @@ func newUpstreamResolver(
|
||||
_ *net.IPNet,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
}
|
||||
|
@ -30,8 +30,9 @@ func newUpstreamResolver(
|
||||
net *net.IPNet,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
|
@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
timeout time.Duration
|
||||
cancelCTX bool
|
||||
expectedAnswer string
|
||||
acceptNXDomain bool
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
@ -36,11 +37,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "Should Not Resolve If Can't Connect To Both Servers",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
||||
timeout: 200 * time.Millisecond,
|
||||
responseShouldBeNil: true,
|
||||
name: "Should Not Resolve If Can't Connect To Both Servers",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
||||
timeout: 200 * time.Millisecond,
|
||||
acceptNXDomain: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Resolve If Parent Context Is Canceled",
|
||||
@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
// should resolve if first upstream times out
|
||||
// should not write when both fails
|
||||
// should not resolve if parent context is canceled
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
|
||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
@ -84,16 +82,22 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
foundAnswer := false
|
||||
for _, answer := range responseMSG.Answer {
|
||||
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
||||
foundAnswer = true
|
||||
break
|
||||
}
|
||||
if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError {
|
||||
return
|
||||
}
|
||||
|
||||
if !foundAnswer {
|
||||
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
||||
if testCase.expectedAnswer != "" {
|
||||
foundAnswer := false
|
||||
for _, answer := range responseMSG.Answer {
|
||||
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
||||
foundAnswer = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundAnswer {
|
||||
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -721,7 +721,9 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return d.nsGroupStates
|
||||
|
||||
// shallow copy is good enough, as slices fields are currently not updated
|
||||
return slices.Clone(d.nsGroupStates)
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||
|
Loading…
x
Reference in New Issue
Block a user