mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-13 05:16:45 +02:00
[client] Support dns upstream failover for nameserver groups with same match domain (#3178)
This commit is contained in:
parent
5953b43ead
commit
488b697479
@ -12,7 +12,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
PriorityDNSRoute = 100
|
PriorityDNSRoute = 100
|
||||||
PriorityMatchDomain = 50
|
PriorityMatchDomain = 50
|
||||||
PriorityDefault = 0
|
PriorityDefault = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
type SubdomainMatcher interface {
|
type SubdomainMatcher interface {
|
||||||
@ -26,7 +26,6 @@ type HandlerEntry struct {
|
|||||||
Pattern string
|
Pattern string
|
||||||
OrigPattern string
|
OrigPattern string
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
StopHandler handlerWithStop
|
|
||||||
MatchSubdomains bool
|
MatchSubdomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||||
if c.handlers[i].StopHandler != nil {
|
|
||||||
c.handlers[i].StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
OrigPattern: origPattern,
|
OrigPattern: origPattern,
|
||||||
IsWildcard: isWildcard,
|
IsWildcard: isWildcard,
|
||||||
StopHandler: stopHandler,
|
|
||||||
MatchSubdomains: matchSubdomains,
|
MatchSubdomains: matchSubdomains,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,9 +137,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
if entry.StopHandler != nil {
|
|
||||||
entry.StopHandler.stop()
|
|
||||||
}
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -180,8 +172,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if log.IsLevelEnabled(log.TraceLevel) {
|
if log.IsLevelEnabled(log.TraceLevel) {
|
||||||
log.Tracef("current handlers (%d):", len(handlers))
|
log.Tracef("current handlers (%d):", len(handlers))
|
||||||
for _, h := range handlers {
|
for _, h := range handlers {
|
||||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,13 +198,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !matched {
|
if !matched {
|
||||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
||||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||||
|
|
||||||
chainWriter := &ResponseWriterChain{
|
chainWriter := &ResponseWriterChain{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
|
@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
dnsRouteHandler := &nbdns.MockHandler{}
|
dnsRouteHandler := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Setup handlers with different priorities
|
// Setup handlers with different priorities
|
||||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
pattern = "*." + tt.handlerDomain[2:]
|
pattern = "*." + tt.handlerDomain[2:]
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and execute request
|
// Create and execute request
|
||||||
@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
handler3 := &nbdns.MockHandler{}
|
handler3 := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Add handlers in priority order
|
// Add handlers in priority order
|
||||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
||||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
handler := &nbdns.MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
handlers[op.priority] = handler
|
handlers[op.priority] = handler
|
||||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
chain.AddHandler(op.pattern, handler, op.priority)
|
||||||
} else {
|
} else {
|
||||||
chain.RemoveHandler(op.pattern, op.priority)
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
}
|
}
|
||||||
@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state with all three handlers
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
handler = mockHandler
|
handler = mockHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
chain.AddHandler(pattern, handler, h.priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute request
|
// Execute request
|
||||||
|
@ -29,10 +29,15 @@ func (d *localResolver) String() string {
|
|||||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (d *localResolver) id() handlerID {
|
||||||
|
return "local-resolver"
|
||||||
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) > 0 {
|
if len(r.Question) > 0 {
|
||||||
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
}
|
}
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -42,7 +41,12 @@ type Server interface {
|
|||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type handlerID string
|
||||||
|
|
||||||
|
type nsGroupsByDomain struct {
|
||||||
|
domain string
|
||||||
|
groups []*nbdns.NameServerGroup
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
type DefaultServer struct {
|
type DefaultServer struct {
|
||||||
@ -52,7 +56,6 @@ type DefaultServer struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
handlerPriorities map[string]int
|
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
@ -77,14 +80,17 @@ type handlerWithStop interface {
|
|||||||
dns.Handler
|
dns.Handler
|
||||||
stop()
|
stop()
|
||||||
probeAvailability()
|
probeAvailability()
|
||||||
|
id() handlerID
|
||||||
}
|
}
|
||||||
|
|
||||||
type muxUpdate struct {
|
type handlerWrapper struct {
|
||||||
domain string
|
domain string
|
||||||
handler handlerWithStop
|
handler handlerWithStop
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type registeredHandlerMap map[handlerID]handlerWrapper
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@ -158,13 +164,12 @@ func newDefaultServer(
|
|||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
disableSys: disableSys,
|
disableSys: disableSys,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
handlerPriorities: make(map[string]int),
|
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@ -192,8 +197,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
|||||||
log.Warn("skipping empty domain")
|
log.Warn("skipping empty domain")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
s.handlerChain.AddHandler(domain, handler, priority)
|
||||||
s.handlerPriorities[domain] = priority
|
|
||||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -209,14 +213,15 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
|||||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
|
if domain == "" {
|
||||||
|
log.Warn("skipping empty domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
s.handlerChain.RemoveHandler(domain, priority)
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
// Only deregister from service if no handlers remain
|
// Only deregister from service if no handlers remain
|
||||||
if !s.handlerChain.HasHandlers(domain) {
|
if !s.handlerChain.HasHandlers(domain) {
|
||||||
if domain == "" {
|
|
||||||
log.Warn("skipping empty domain")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,14 +288,24 @@ func (s *DefaultServer) Stop() {
|
|||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
|
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
s.hostsDNSHolder.set(hostsDnsList)
|
s.hostsDNSHolder.set(hostsDnsList)
|
||||||
|
|
||||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
// Check if there's any root handler
|
||||||
if ok {
|
var hasRootHandler bool
|
||||||
|
for _, handler := range s.dnsMuxMap {
|
||||||
|
if handler.domain == nbdns.RootZone {
|
||||||
|
hasRootHandler = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasRootHandler {
|
||||||
log.Debugf("on new host DNS config but skip to apply it")
|
log.Debugf("on new host DNS config but skip to apply it")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
}
|
}
|
||||||
@ -364,7 +379,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
|||||||
go func(mux handlerWithStop) {
|
go func(mux handlerWithStop) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
mux.probeAvailability()
|
mux.probeAvailability()
|
||||||
}(mux)
|
}(mux.handler)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
@ -419,8 +434,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
|
||||||
var muxUpdates []muxUpdate
|
var muxUpdates []handlerWrapper
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
@ -428,7 +443,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@ -446,15 +461,59 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
return muxUpdates, localRecords, nil
|
return muxUpdates, localRecords, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||||
|
var muxUpdates []handlerWrapper
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
for _, nsGroup := range nameServerGroups {
|
||||||
if len(nsGroup.NameServers) == 0 {
|
if len(nsGroup.NameServers) == 0 {
|
||||||
log.Warn("received a nameserver group with empty nameserver list")
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !nsGroup.Primary && len(nsGroup.Domains) == 0 {
|
||||||
|
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||||
|
|
||||||
|
for _, domainGroup := range groupedNS {
|
||||||
|
basePriority := PriorityMatchDomain
|
||||||
|
if domainGroup.domain == nbdns.RootZone {
|
||||||
|
basePriority = PriorityDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
muxUpdates = append(muxUpdates, updates...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return muxUpdates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) {
|
||||||
|
var muxUpdates []handlerWrapper
|
||||||
|
|
||||||
|
for i, nsGroup := range domainGroup.groups {
|
||||||
|
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||||
|
priority := basePriority - i
|
||||||
|
|
||||||
|
// Check if we're about to overlap with the next priority tier
|
||||||
|
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
||||||
|
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||||
|
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
|
||||||
handler, err := newUpstreamResolver(
|
handler, err := newUpstreamResolver(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
s.wgInterface.Name(),
|
s.wgInterface.Name(),
|
||||||
@ -462,10 +521,12 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
|
domainGroup.domain,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
|
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
|
||||||
@ -489,78 +550,47 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
// after some period defined by upstream it tries to reactivate self by calling this hook
|
// after some period defined by upstream it tries to reactivate self by calling this hook
|
||||||
// everything we need here is just to re-apply current configuration because it already
|
// everything we need here is just to re-apply current configuration because it already
|
||||||
// contains this upstream settings (temporal deactivation not removed it)
|
// contains this upstream settings (temporal deactivation not removed it)
|
||||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority)
|
||||||
|
|
||||||
if nsGroup.Primary {
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
domain: domainGroup.domain,
|
||||||
domain: nbdns.RootZone,
|
handler: handler,
|
||||||
handler: handler,
|
priority: priority,
|
||||||
priority: PriorityDefault,
|
})
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
priority: PriorityMatchDomain,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return muxUpdates, nil
|
return muxUpdates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
handlersByPriority := make(map[string]int)
|
for _, existing := range s.dnsMuxMap {
|
||||||
|
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||||
var isContainRootUpdate bool
|
existing.handler.stop()
|
||||||
|
|
||||||
// First register new handlers
|
|
||||||
for _, update := range muxUpdates {
|
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
|
||||||
muxUpdateMap[update.domain] = update.handler
|
|
||||||
handlersByPriority[update.domain] = update.priority
|
|
||||||
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
|
||||||
existingHandler.stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if update.domain == nbdns.RootZone {
|
|
||||||
isContainRootUpdate = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then deregister old handlers not in the update
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
var containsRootUpdate bool
|
||||||
_, found := muxUpdateMap[key]
|
|
||||||
if !found {
|
for _, update := range muxUpdates {
|
||||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
if update.domain == nbdns.RootZone {
|
||||||
|
containsRootUpdate = true
|
||||||
|
}
|
||||||
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
|
muxUpdateMap[update.handler.id()] = update
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's no root update and we had a root handler, restore it
|
||||||
|
if !containsRootUpdate {
|
||||||
|
for _, existing := range s.dnsMuxMap {
|
||||||
|
if existing.domain == nbdns.RootZone {
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
existingHandler.stop()
|
break
|
||||||
} else {
|
|
||||||
existingHandler.stop()
|
|
||||||
// Deregister with the priority that was used to register
|
|
||||||
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
|
||||||
s.deregisterHandler([]string{key}, oldPriority)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
s.handlerPriorities = handlersByPriority
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
@ -593,6 +623,7 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
|||||||
func (s *DefaultServer) upstreamCallbacks(
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
nsGroup *nbdns.NameServerGroup,
|
nsGroup *nbdns.NameServerGroup,
|
||||||
handler dns.Handler,
|
handler dns.Handler,
|
||||||
|
priority int,
|
||||||
) (deactivate func(error), reactivate func()) {
|
) (deactivate func(error), reactivate func()) {
|
||||||
var removeIndex map[string]int
|
var removeIndex map[string]int
|
||||||
deactivate = func(err error) {
|
deactivate = func(err error) {
|
||||||
@ -609,13 +640,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
s.deregisterHandler([]string{nbdns.RootZone}, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
s.deregisterHandler([]string{item.Domain}, priority)
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -635,8 +666,8 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.updateNSState(nsGroup, err, false)
|
s.updateNSState(nsGroup, err, false)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
reactivate = func() {
|
reactivate = func() {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@ -646,7 +677,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
s.registerHandler([]string{domain}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@ -654,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.hostManager != nil {
|
if s.hostManager != nil {
|
||||||
@ -676,6 +707,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
s.wgInterface.Address().Network,
|
s.wgInterface.Address().Network,
|
||||||
s.statusRecorder,
|
s.statusRecorder,
|
||||||
s.hostsDNSHolder,
|
s.hostsDNSHolder,
|
||||||
|
nbdns.RootZone,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
@ -732,5 +764,34 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
|||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s_%s_%s", nsGroup.ID, nsGroup.Name, strings.Join(servers, ","))
|
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupNSGroupsByDomain groups nameserver groups by their match domains
|
||||||
|
func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain {
|
||||||
|
domainMap := make(map[string][]*nbdns.NameServerGroup)
|
||||||
|
|
||||||
|
for _, group := range nsGroups {
|
||||||
|
if group.Primary {
|
||||||
|
domainMap[nbdns.RootZone] = append(domainMap[nbdns.RootZone], group)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range group.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domainMap[domain] = append(domainMap[domain], group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []nsGroupsByDomain
|
||||||
|
for domain, groups := range domainMap {
|
||||||
|
result = append(result, nsGroupsByDomain{
|
||||||
|
domain: domain,
|
||||||
|
groups: groups,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
@ -88,6 +89,18 @@ func init() {
|
|||||||
formatter.SetTextFormatter(log.StandardLogger())
|
formatter.SetTextFormatter(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||||
|
var srvs []string
|
||||||
|
for _, srv := range servers {
|
||||||
|
srvs = append(srvs, getNSHostPort(srv))
|
||||||
|
}
|
||||||
|
return &upstreamResolverBase{
|
||||||
|
domain: domain,
|
||||||
|
upstreamServers: srvs,
|
||||||
|
cancel: func() {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
@ -140,15 +153,37 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||||
|
domain: "netbird.io",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
dummyHandler.id(): handlerWrapper{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
initSerial: 0,
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
inputSerial: 1,
|
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
initSerial: 0,
|
||||||
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
ServiceEnable: true,
|
ServiceEnable: true,
|
||||||
CustomZones: []nbdns.CustomZone{
|
CustomZones: []nbdns.CustomZone{
|
||||||
@ -164,8 +199,19 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
||||||
|
domain: "netbird.io",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"local-resolver": handlerWrapper{
|
||||||
|
domain: "netbird.cloud",
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
@ -242,9 +288,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
shouldFail: true,
|
shouldFail: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
@ -252,9 +304,15 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||||
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
initUpstreamMap: registeredHandlerMap{
|
||||||
|
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: dummyHandler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
@ -421,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||||
|
"id1": handlerWrapper{
|
||||||
|
domain: zoneRecords[0].Name,
|
||||||
|
handler: &localResolver{},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
}
|
||||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||||
dnsServer.updateSerial = 0
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
@ -562,9 +626,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
handlerPriorities: make(map[string]int),
|
hostManager: hostManager,
|
||||||
hostManager: hostManager,
|
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
Domains: []DomainConfig{
|
Domains: []DomainConfig{
|
||||||
{false, "domain0", false},
|
{false, "domain0", false},
|
||||||
@ -593,7 +656,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
NameServers: []nbdns.NameServer{
|
NameServers: []nbdns.NameServer{
|
||||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, nil, 0)
|
||||||
|
|
||||||
deactivate(nil)
|
deactivate(nil)
|
||||||
expected := "domain0,domain2"
|
expected := "domain0,domain2"
|
||||||
@ -903,8 +966,8 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
Subdomains: true,
|
Subdomains: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@ -959,3 +1022,421 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockHandler struct {
|
||||||
|
Id string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
|
func (m *mockHandler) stop() {}
|
||||||
|
func (m *mockHandler) probeAvailability() {}
|
||||||
|
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
||||||
|
|
||||||
|
type mockService struct{}
|
||||||
|
|
||||||
|
func (m *mockService) Listen() error { return nil }
|
||||||
|
func (m *mockService) Stop() {}
|
||||||
|
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
|
||||||
|
func (m *mockService) RuntimePort() int { return 53 }
|
||||||
|
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||||
|
func (m *mockService) DeregisterMux(string) {}
|
||||||
|
|
||||||
|
func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||||
|
baseMatchHandlers := registeredHandlerMap{
|
||||||
|
"upstream-group1": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"upstream-group2": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
baseRootHandlers := registeredHandlerMap{
|
||||||
|
"upstream-root1": {
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
"upstream-root2": {
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
baseMixedHandlers := registeredHandlerMap{
|
||||||
|
"upstream-group1": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
"upstream-group2": {
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
"upstream-other": {
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialHandlers registeredHandlerMap
|
||||||
|
updates []handlerWrapper
|
||||||
|
expectedHandlers map[string]string // map[handlerID]domain
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Remove group1 from update",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Only group2 remains
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
},
|
||||||
|
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Remove group2 from update",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Only group1 remains
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
},
|
||||||
|
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add group3 in first position",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Add group3 with highest priority
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group3",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain + 1,
|
||||||
|
},
|
||||||
|
// Keep existing groups with their original priorities
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-group3": "example.com",
|
||||||
|
},
|
||||||
|
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add group3 in last position",
|
||||||
|
initialHandlers: baseMatchHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
// Keep existing groups with their original priorities
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
// Add group3 with lowest priority
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group3",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-group3": "example.com",
|
||||||
|
},
|
||||||
|
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
|
||||||
|
},
|
||||||
|
// Root zone tests
|
||||||
|
{
|
||||||
|
name: "Remove root1 from update",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root2": ".",
|
||||||
|
},
|
||||||
|
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Remove root2 from update",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
},
|
||||||
|
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add root3 in first position",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root3",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault + 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
"upstream-root2": ".",
|
||||||
|
"upstream-root3": ".",
|
||||||
|
},
|
||||||
|
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add root3 in last position",
|
||||||
|
initialHandlers: baseRootHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root1",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root2",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: ".",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-root3",
|
||||||
|
},
|
||||||
|
priority: PriorityDefault - 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-root1": ".",
|
||||||
|
"upstream-root2": ".",
|
||||||
|
"upstream-root3": ".",
|
||||||
|
},
|
||||||
|
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
|
||||||
|
},
|
||||||
|
// Mixed domain tests
|
||||||
|
{
|
||||||
|
name: "Update with mixed domains - remove one of duplicate domain",
|
||||||
|
initialHandlers: baseMixedHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-other": "other.com",
|
||||||
|
},
|
||||||
|
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update with mixed domains - add new domain",
|
||||||
|
initialHandlers: baseMixedHandlers,
|
||||||
|
updates: []handlerWrapper{
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group1",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "example.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-group2",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain - 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "other.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-other",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "new.com",
|
||||||
|
handler: &mockHandler{
|
||||||
|
Id: "upstream-new",
|
||||||
|
},
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedHandlers: map[string]string{
|
||||||
|
"upstream-group1": "example.com",
|
||||||
|
"upstream-group2": "example.com",
|
||||||
|
"upstream-other": "other.com",
|
||||||
|
"upstream-new": "new.com",
|
||||||
|
},
|
||||||
|
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &DefaultServer{
|
||||||
|
dnsMuxMap: tt.initialHandlers,
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
service: &mockService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the update
|
||||||
|
server.updateMux(tt.updates)
|
||||||
|
|
||||||
|
// Verify the results
|
||||||
|
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
||||||
|
"Number of handlers after update doesn't match expected")
|
||||||
|
|
||||||
|
// Check each expected handler
|
||||||
|
for id, expectedDomain := range tt.expectedHandlers {
|
||||||
|
handler, exists := server.dnsMuxMap[handlerID(id)]
|
||||||
|
assert.True(t, exists, "Expected handler %s not found", id)
|
||||||
|
if exists {
|
||||||
|
assert.Equal(t, expectedDomain, handler.domain,
|
||||||
|
"Domain mismatch for handler %s", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no unexpected handlers exist
|
||||||
|
for handlerID := range server.dnsMuxMap {
|
||||||
|
_, expected := tt.expectedHandlers[string(handlerID)]
|
||||||
|
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the handlerChain state and order
|
||||||
|
previousPriority := 0
|
||||||
|
for _, chainEntry := range server.handlerChain.handlers {
|
||||||
|
// Verify priority order
|
||||||
|
if previousPriority > 0 {
|
||||||
|
assert.True(t, chainEntry.Priority <= previousPriority,
|
||||||
|
"Handlers in chain not properly ordered by priority")
|
||||||
|
}
|
||||||
|
previousPriority = chainEntry.Priority
|
||||||
|
|
||||||
|
// Verify handler exists in mux
|
||||||
|
foundInMux := false
|
||||||
|
for _, muxEntry := range server.dnsMuxMap {
|
||||||
|
if chainEntry.Handler == muxEntry.handler &&
|
||||||
|
chainEntry.Priority == muxEntry.priority &&
|
||||||
|
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
||||||
|
foundInMux = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, foundInMux,
|
||||||
|
"Handler in chain not found in dnsMuxMap")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,9 +2,13 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -40,6 +44,7 @@ type upstreamResolverBase struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
upstreamClient upstreamClient
|
upstreamClient upstreamClient
|
||||||
upstreamServers []string
|
upstreamServers []string
|
||||||
|
domain string
|
||||||
disabled bool
|
disabled bool
|
||||||
failsCount atomic.Int32
|
failsCount atomic.Int32
|
||||||
successCount atomic.Int32
|
successCount atomic.Int32
|
||||||
@ -53,12 +58,13 @@ type upstreamResolverBase struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
|
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
|
domain: domain,
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: upstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string {
|
|||||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (u *upstreamResolverBase) id() handlerID {
|
||||||
|
servers := slices.Clone(u.upstreamServers)
|
||||||
|
slices.Sort(servers)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
hash.Write([]byte(u.domain + ":"))
|
||||||
|
hash.Write([]byte(strings.Join(servers, ",")))
|
||||||
|
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
u.checkUpstreamFails(err)
|
u.checkUpstreamFails(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.SetEdns0(4096, false)
|
r.SetEdns0(4096, false)
|
||||||
@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
|
log.Tracef("%s has been stopped", u)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||||
Warn("got an error while connecting to upstream")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
continue
|
||||||
Error("got other error while querying the upstream")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil {
|
if rm == nil || !rm.Response {
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||||
Warn("no response from upstream")
|
continue
|
||||||
return
|
|
||||||
}
|
|
||||||
// those checks need to be independent of each other due to memory address issues
|
|
||||||
if !rm.Response {
|
|
||||||
log.WithError(err).WithField("upstream", upstream).
|
|
||||||
Warn("no response from upstream")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||||
|
|
||||||
err = w.WriteMsg(rm)
|
if err = w.WriteMsg(rm); err != nil {
|
||||||
if err != nil {
|
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||||
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
|
||||||
}
|
}
|
||||||
// count the fails only if they happen sequentially
|
// count the fails only if they happen sequentially
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
u.failsCount.Add(1)
|
||||||
log.Error("all queries to the upstream nameservers failed with timeout")
|
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetRcode(r, dns.RcodeServerFailure)
|
||||||
|
if err := w.WriteMsg(m); err != nil {
|
||||||
|
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||||
|
@ -27,8 +27,9 @@ func newUpstreamResolver(
|
|||||||
_ *net.IPNet,
|
_ *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
c := &upstreamResolver{
|
c := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
hostsDNSHolder: hostsDNSHolder,
|
hostsDNSHolder: hostsDNSHolder,
|
||||||
|
@ -23,8 +23,9 @@ func newUpstreamResolver(
|
|||||||
_ *net.IPNet,
|
_ *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolver, error) {
|
) (*upstreamResolver, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
nonIOS := &upstreamResolver{
|
nonIOS := &upstreamResolver{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
}
|
}
|
||||||
|
@ -30,8 +30,9 @@ func newUpstreamResolver(
|
|||||||
net *net.IPNet,
|
net *net.IPNet,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
|
domain string,
|
||||||
) (*upstreamResolverIOS, error) {
|
) (*upstreamResolverIOS, error) {
|
||||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||||
|
|
||||||
ios := &upstreamResolverIOS{
|
ios := &upstreamResolverIOS{
|
||||||
upstreamResolverBase: upstreamResolverBase,
|
upstreamResolverBase: upstreamResolverBase,
|
||||||
|
@ -20,6 +20,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cancelCTX bool
|
cancelCTX bool
|
||||||
expectedAnswer string
|
expectedAnswer string
|
||||||
|
acceptNXDomain bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Should Resolve A Record",
|
name: "Should Resolve A Record",
|
||||||
@ -36,11 +37,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
expectedAnswer: "1.1.1.1",
|
expectedAnswer: "1.1.1.1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Should Not Resolve If Can't Connect To Both Servers",
|
name: "Should Not Resolve If Can't Connect To Both Servers",
|
||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
||||||
timeout: 200 * time.Millisecond,
|
timeout: 200 * time.Millisecond,
|
||||||
responseShouldBeNil: true,
|
acceptNXDomain: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Should Not Resolve If Parent Context Is Canceled",
|
name: "Should Not Resolve If Parent Context Is Canceled",
|
||||||
@ -51,14 +52,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
responseShouldBeNil: true,
|
responseShouldBeNil: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// should resolve if first upstream times out
|
|
||||||
// should not write when both fails
|
|
||||||
// should not resolve if parent context is canceled
|
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
|
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
||||||
resolver.upstreamServers = testCase.InputServers
|
resolver.upstreamServers = testCase.InputServers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
@ -84,16 +82,22 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
t.Fatalf("should write a response message")
|
t.Fatalf("should write a response message")
|
||||||
}
|
}
|
||||||
|
|
||||||
foundAnswer := false
|
if testCase.acceptNXDomain && responseMSG.Rcode == dns.RcodeNameError {
|
||||||
for _, answer := range responseMSG.Answer {
|
return
|
||||||
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
|
||||||
foundAnswer = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !foundAnswer {
|
if testCase.expectedAnswer != "" {
|
||||||
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
foundAnswer := false
|
||||||
|
for _, answer := range responseMSG.Answer {
|
||||||
|
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
||||||
|
foundAnswer = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundAnswer {
|
||||||
|
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -721,7 +721,9 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
func (d *Status) GetDNSStates() []NSGroupState {
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return d.nsGroupStates
|
|
||||||
|
// shallow copy is good enough, as slices fields are currently not updated
|
||||||
|
return slices.Clone(d.nsGroupStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user