Add debug output for timeouts

This commit is contained in:
Viktor Liu
2025-07-09 19:00:17 +02:00
parent a1cb7b4af6
commit 629757c911
10 changed files with 736 additions and 204 deletions

View File

@@ -0,0 +1,155 @@
package config
import (
"fmt"
"net"
"net/netip"
"net/url"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// ServerDomains represents the management server domains extracted from NetBird configuration
type ServerDomains struct {
Signal domain.Domain
Relay []domain.Domain
Flow domain.Domain
Stuns []domain.Domain
Turns []domain.Domain
}
// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
if config == nil {
return ServerDomains{}
}
domains := ServerDomains{}
domains.Signal = extractSignalDomain(config)
domains.Relay = extractRelayDomains(config)
domains.Flow = extractFlowDomain(config)
domains.Stuns = extractStunDomains(config)
domains.Turns = extractTurnDomains(config)
return domains
}
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func extractValidDomain(rawURL string) (domain.Domain, error) {
parsedURL, err := url.Parse(rawURL)
if err != nil {
// If URL parsing fails, it might be a raw host:port, try parsing as such
if host, _, err := net.SplitHostPort(rawURL); err == nil {
return extractDomainFromHost(host)
}
// If not host:port, try as raw hostname
return extractDomainFromHost(rawURL)
}
host := parsedURL.Hostname()
if host == "" {
return "", fmt.Errorf("no hostname in URL")
}
return extractDomainFromHost(host)
}
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
func extractDomainFromHost(host string) (domain.Domain, error) {
if host == "" {
return "", fmt.Errorf("empty host")
}
if _, err := netip.ParseAddr(host); err == nil {
return "", fmt.Errorf("IP address not allowed: %s", host)
}
d, err := domain.FromString(host)
if err != nil {
return "", fmt.Errorf("invalid domain: %v", err)
}
return d, nil
}
// extractSingleDomain extracts a single domain from a URL with error logging
func extractSingleDomain(url, serviceType string) domain.Domain {
if url == "" {
return ""
}
d, err := extractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
return ""
}
return d
}
// extractMultipleDomains extracts multiple domains from URLs with error logging
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
var domains []domain.Domain
for _, url := range urls {
if url == "" {
continue
}
d, err := extractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
continue
}
domains = append(domains, d)
}
return domains
}
// extractSignalDomain extracts the signal domain from NetBird configuration.
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Signal != nil {
return extractSingleDomain(config.Signal.Uri, "signal")
}
return ""
}
// extractRelayDomains extracts relay server domains from NetBird configuration.
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay != nil {
return extractMultipleDomains(config.Relay.Urls, "relay")
}
return nil
}
// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Flow != nil {
return extractSingleDomain(config.Flow.Url, "flow")
}
return ""
}
// extractStunDomains extracts STUN server domains from NetBird configuration.
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
urls = append(urls, stun.Uri)
}
}
return extractMultipleDomains(urls, "STUN")
}
// extractTurnDomains extracts TURN server domains from NetBird configuration.
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
urls = append(urls, turn.HostConfig.Uri)
}
}
return extractMultipleDomains(urls, "TURN")
}

View File

@@ -182,7 +182,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler // If handler wants to continue, try next handler
if chainWriter.shouldContinue { if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler for domain=%s", qname) // Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue continue
} }
return return

View File

@@ -3,6 +3,7 @@ package mgmt
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"net/netip" "net/netip"
"net/url" "net/url"
@@ -13,6 +14,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -27,14 +29,12 @@ type CacheEntry struct {
type Resolver struct { type Resolver struct {
cache map[domain.Domain]CacheEntry cache map[domain.Domain]CacheEntry
mutex sync.RWMutex mutex sync.RWMutex
systemResolver *net.Resolver
} }
// NewResolver creates a new management domains cache resolver. // NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
cache: make(map[domain.Domain]CacheEntry), cache: make(map[domain.Domain]CacheEntry),
systemResolver: net.DefaultResolver,
} }
} }
@@ -58,22 +58,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype])
m.mutex.RLock() m.mutex.RLock()
parsedDomain, err := domain.FromString(qname) domainKey := domain.Domain(qname)
if err != nil { entry, found := m.cache[domainKey]
log.Tracef("MgmtCache: invalid domain format: %s", qname)
m.mutex.RUnlock()
m.continueToNext(w, r)
return
}
entry, found := m.cache[parsedDomain]
m.mutex.RUnlock() m.mutex.RUnlock()
if !found { if !found {
log.Tracef("MgmtCache: no cache entry found for domain=%s", qname)
m.continueToNext(w, r) m.continueToNext(w, r)
return return
} }
@@ -91,7 +81,6 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
if len(records) == 0 { if len(records) == 0 {
log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString())
m.continueToNext(w, r) m.continueToNext(w, r)
return return
} }
@@ -102,10 +91,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp.Answer = append(resp.Answer, rrCopy) resp.Answer = append(resp.Answer, rrCopy)
} }
log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString()) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("MgmtCache: failed to write response: %v", err) log.Errorf("failed to write response: %v", err)
} }
} }
@@ -120,60 +109,59 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
resp.SetRcode(r, dns.RcodeNameError) resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("MgmtCache: failed to write continue signal: %v", err) log.Errorf("failed to write continue signal: %v", err)
} }
} }
// AddDomain manually adds a domain to cache by resolving it. // AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString()) log.Debugf("adding domain=%s to cache", d.SafeString())
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel() defer cancel()
var aRecords, aaaaRecords []dns.RR ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil { return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}
m.mutex.Lock()
m.cache[d] = CacheEntry{
ARecords: aRecords,
AAAARecords: aaaaRecords,
}
m.mutex.Unlock()
log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
} else {
log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err)
return err
} }
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}
m.mutex.Lock()
m.cache[d] = CacheEntry{
ARecords: aRecords,
AAAARecords: aaaaRecords,
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil return nil
} }
@@ -182,7 +170,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
if mgmtURL != nil { if mgmtURL != nil {
if d, err := extractDomainFromURL(mgmtURL); err == nil { if d, err := extractDomainFromURL(mgmtURL); err == nil {
if err := m.AddDomain(ctx, d); err != nil { if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add management domain: %v", err) log.Warnf("failed to add management domain: %v", err)
} }
} }
} }
@@ -190,6 +178,16 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
return nil return nil
} }
// RemoveDomain removes a domain from the cache.
func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.cache, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
}
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config. // PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error { func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
if config == nil { if config == nil {
@@ -216,19 +214,19 @@ func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostCon
// If parsing fails, it might be a raw host:port, try adding a scheme // If parsing fails, it might be a raw host:port, try adding a scheme
signalURL, err = url.Parse("https://" + signal.Uri) signalURL, err = url.Parse("https://" + signal.Uri)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to parse signal URL: %v", err) log.Warnf("failed to parse signal URL: %v", err)
return return
} }
} }
d, err := extractDomainFromURL(signalURL) d, err := extractDomainFromURL(signalURL)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to extract signal domain: %v", err) log.Warnf("failed to extract signal domain: %v", err)
return return
} }
if err := m.AddDomain(ctx, d); err != nil { if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add signal domain: %v", err) log.Warnf("failed to add signal domain: %v", err)
} }
} }
@@ -241,18 +239,18 @@ func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayCon
for _, relayAddr := range relay.Urls { for _, relayAddr := range relay.Urls {
relayURL, err := url.Parse(relayAddr) relayURL, err := url.Parse(relayAddr)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err) log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
continue continue
} }
d, err := extractDomainFromURL(relayURL) d, err := extractDomainFromURL(relayURL)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err) log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
continue continue
} }
if err := m.AddDomain(ctx, d); err != nil { if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add relay domain: %v", err) log.Warnf("failed to add relay domain: %v", err)
} }
} }
} }
@@ -265,18 +263,18 @@ func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig)
flowURL, err := url.Parse(flow.Url) flowURL, err := url.Parse(flow.Url)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to parse flow URL: %v", err) log.Warnf("failed to parse flow URL: %v", err)
return return
} }
d, err := extractDomainFromURL(flowURL) d, err := extractDomainFromURL(flowURL)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to extract flow domain: %v", err) log.Warnf("failed to extract flow domain: %v", err)
return return
} }
if err := m.AddDomain(ctx, d); err != nil { if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add flow domain: %v", err) log.Warnf("failed to add flow domain: %v", err)
} }
} }
@@ -303,7 +301,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
} }
m.cache = make(map[domain.Domain]CacheEntry) m.cache = make(map[domain.Domain]CacheEntry)
log.Debugf("MgmtCache: cleared %d cached domains", len(domains)) log.Debugf("cleared %d cached domains", len(domains))
return domains return domains
} }
@@ -311,7 +309,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations. // UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
// Returns domains that were removed for external deregistration. // Returns domains that were removed for external deregistration.
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) { func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
log.Debugf("MgmtCache: updating cache from NetbirdConfig") log.Debugf("updating cache from NetbirdConfig")
currentDomains := m.GetCachedDomains() currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromConfig(config) newDomains := m.extractDomainsFromConfig(config)
@@ -333,19 +331,86 @@ func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto
m.mutex.Lock() m.mutex.Lock()
for _, domainToRemove := range removedDomains { for _, domainToRemove := range removedDomains {
delete(m.cache, domainToRemove) delete(m.cache, domainToRemove)
log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString()) log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
} }
m.mutex.Unlock() m.mutex.Unlock()
for _, newDomain := range newDomains { for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil { if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err) log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} }
} }
return removedDomains, nil return removedDomains, nil
} }
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) {
log.Debugf("updating cache from ServerDomains")
currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains []domain.Domain
for _, currentDomain := range currentDomains {
found := false
for _, newDomain := range newDomains {
if currentDomain.SafeString() == newDomain.SafeString() {
found = true
break
}
}
if !found {
removedDomains = append(removedDomains, currentDomain)
if err := m.RemoveDomain(currentDomain); err != nil {
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
}
}
}
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
return removedDomains, nil
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain {
var domains []domain.Domain
if serverDomains.Signal != "" {
domains = append(domains, serverDomains.Signal)
}
for _, relay := range serverDomains.Relay {
if relay != "" {
domains = append(domains, relay)
}
}
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
for _, stun := range serverDomains.Stuns {
if stun != "" {
domains = append(domains, stun)
}
}
for _, turn := range serverDomains.Turns {
if turn != "" {
domains = append(domains, turn)
}
}
return domains
}
// extractDomainsFromConfig extracts all domains from a NetbirdConfig. // extractDomainsFromConfig extracts all domains from a NetbirdConfig.
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain { func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
if config == nil { if config == nil {
@@ -354,26 +419,62 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
var domains []domain.Domain var domains []domain.Domain
if config.Signal != nil && config.Signal.Uri != "" { // Extract signal domain
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil { domains = append(domains, m.extractSignalDomain(config)...)
// Extract relay domains
domains = append(domains, m.extractRelayDomains(config)...)
// Extract flow domain
domains = append(domains, m.extractFlowDomain(config)...)
// Extract STUN domains
domains = append(domains, m.extractSTUNDomains(config)...)
// Extract TURN domains
domains = append(domains, m.extractTURNDomains(config)...)
return domains
}
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Signal == nil || config.Signal.Uri == "" {
return nil
}
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay == nil {
return nil
}
var domains []domain.Domain
for _, relayURL := range config.Relay.Urls {
if d, err := m.extractDomainFromURL(relayURL); err == nil {
domains = append(domains, d) domains = append(domains, d)
} }
} }
return domains
}
if config.Relay != nil { func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
for _, relayURL := range config.Relay.Urls { if config.Flow == nil || config.Flow.Url == "" {
if d, err := m.extractDomainFromURL(relayURL); err == nil { return nil
domains = append(domains, d)
}
}
} }
if config.Flow != nil && config.Flow.Url != "" { if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil { return []domain.Domain{d}
domains = append(domains, d)
}
} }
return nil
}
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, stun := range config.Stuns { for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" { if stun != nil && stun.Uri != "" {
if d, err := m.extractDomainFromURL(stun.Uri); err == nil { if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
@@ -381,7 +482,11 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
} }
} }
} }
return domains
}
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, turn := range config.Turns { for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil { if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
@@ -389,7 +494,6 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
} }
} }
} }
return domains return domains
} }
@@ -424,18 +528,18 @@ func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostCon
stunURL, err := url.Parse(stun.Uri) stunURL, err := url.Parse(stun.Uri)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err) log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
continue continue
} }
d, err := extractDomainFromURL(stunURL) d, err := extractDomainFromURL(stunURL)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err) log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
continue continue
} }
if err := m.AddDomain(ctx, d); err != nil { if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add STUN domain: %v", err) log.Warnf("failed to add STUN domain: %v", err)
} }
} }
} }
@@ -449,18 +553,18 @@ func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.Protect
turnURL, err := url.Parse(turn.HostConfig.Uri) turnURL, err := url.Parse(turn.HostConfig.Uri)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err) log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
continue continue
} }
d, err := extractDomainFromURL(turnURL) d, err := extractDomainFromURL(turnURL)
if err != nil { if err != nil {
log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err) log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
continue continue
} }
if err := m.AddDomain(ctx, d); err != nil { if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add TURN domain: %v", err) log.Warnf("failed to add TURN domain: %v", err)
} }
} }
} }

View File

@@ -5,7 +5,6 @@ import (
"net" "net"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -114,16 +113,15 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
mgmtURL, _ := url.Parse("https://api.netbird.io") // Use IP address to avoid DNS resolution timeout
mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL) err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.NoError(t, err) assert.NoError(t, err)
// Give some time for async population // IP addresses are rejected, so no domains should be cached
time.Sleep(100 * time.Millisecond)
domains := resolver.GetCachedDomains() domains := resolver.GetCachedDomains()
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
} }
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) { func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
@@ -132,32 +130,33 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
// Use IP addresses to avoid DNS resolution timeouts
netbirdConfig := &mgmProto.NetbirdConfig{ netbirdConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{ Signal: &mgmProto.HostConfig{
Uri: "https://signal.netbird.io", Uri: "https://10.0.0.1",
}, },
Relay: &mgmProto.RelayConfig{ Relay: &mgmProto.RelayConfig{
Urls: []string{ Urls: []string{
"https://relay1.netbird.io:443", "https://10.0.0.2:443",
"https://relay2.netbird.io:443", "https://10.0.0.3:443",
}, },
}, },
Flow: &mgmProto.FlowConfig{ Flow: &mgmProto.FlowConfig{
Url: "https://flow.netbird.io:80", Url: "https://10.0.0.4:80",
}, },
Stuns: []*mgmProto.HostConfig{ Stuns: []*mgmProto.HostConfig{
{Uri: "stun:stun1.netbird.io:3478"}, {Uri: "stun:10.0.0.5:3478"},
{Uri: "stun:stun2.netbird.io:3478"}, {Uri: "stun:10.0.0.6:3478"},
}, },
Turns: []*mgmProto.ProtectedHostConfig{ Turns: []*mgmProto.ProtectedHostConfig{
{ {
HostConfig: &mgmProto.HostConfig{ HostConfig: &mgmProto.HostConfig{
Uri: "turn:turn1.netbird.io:3478", Uri: "turn:10.0.0.7:3478",
}, },
}, },
{ {
HostConfig: &mgmProto.HostConfig{ HostConfig: &mgmProto.HostConfig{
Uri: "turn:turn2.netbird.io:3478", Uri: "turn:10.0.0.8:3478",
}, },
}, },
}, },
@@ -166,11 +165,42 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig) err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
assert.NoError(t, err) assert.NoError(t, err)
// Give some time for async population // IP addresses are rejected, so no domains should be cached
time.Sleep(100 * time.Millisecond)
domains := resolver.GetCachedDomains() domains := resolver.GetCachedDomains()
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) {
resolver := NewResolver()
// Test with empty initial config and then add domains
initialConfig := &mgmProto.NetbirdConfig{}
// Start with empty config
removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig)
assert.NoError(t, err)
assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache")
// Update to config with IP addresses instead of domains to avoid DNS resolution
// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
updatedConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{
Uri: "https://127.0.0.1",
},
Flow: &mgmProto.FlowConfig{
Url: "https://192.168.1.1:80",
},
}
removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig)
assert.NoError(t, err)
// Verify the method completes successfully without DNS timeouts
assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update")
// Verify no domains were actually added since IPs are rejected
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
} }
func TestResolver_ContinueToNext(t *testing.T) { func TestResolver_ContinueToNext(t *testing.T) {

View File

@@ -5,17 +5,19 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
) )
// MockServer is the mock instance of a dns server // MockServer is the mock instance of a dns server
type MockServer struct { type MockServer struct {
InitializeFunc func() error InitializeFunc func() error
StopFunc func() StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func(domain.List, dns.Handler, int) RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int) DeregisterHandlerFunc func(domain.List, int)
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
} }
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -69,3 +71,10 @@ func (m *MockServer) SearchDomains() []string {
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() { func (m *MockServer) ProbeAvailability() {
} }
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
}
return nil
}

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/mgmt"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
@@ -25,7 +26,6 @@ import (
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
) )
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
@@ -49,6 +49,7 @@ type Server interface {
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
} }
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
@@ -103,20 +104,23 @@ type handlerWrapper struct {
type registeredHandlerMap map[types.HandlerID]handlerWrapper type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
Ctx context.Context
WgInterface WGIface
CustomAddress string
StatusRecorder *peer.Status
StateManager *statemanager.Manager
DisableSys bool
MgmtURL *url.URL
ServerDomains dnsconfig.ServerDomains
}
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
mgmtURL *url.URL,
netbirdConfig *mgmProto.NetbirdConfig,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
} }
@@ -124,31 +128,23 @@ func NewDefaultServer(
} }
var dnsService service var dnsService service
if wgInterface.IsUserspaceBind() { if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface) dnsService = NewServiceViaMemory(config.WgInterface)
} else { } else {
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(config.WgInterface, addrPort)
} }
server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys) server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
// Pre-populate management cache with management URL if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
if mgmtURL != nil && server.mgmtCacheResolver != nil { if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil {
log.Warnf("Failed to populate management cache from management URL: %v", err) log.Warnf("Failed to populate management cache from management URL: %v", err)
} }
} }
// Pre-populate management cache with NetbirdConfig domains if server.mgmtCacheResolver != nil {
if netbirdConfig != nil && server.mgmtCacheResolver != nil { if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil { log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err)
}
// Register newly populated domains
domains := server.mgmtCacheResolver.GetCachedDomains()
if len(domains) > 0 {
server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache)
} }
} }
@@ -220,19 +216,11 @@ func newDefaultServer(
mgmtCacheResolver: mgmtCacheResolver, mgmtCacheResolver: mgmtCacheResolver,
} }
// Register cached domains with the handler chain domains := mgmtCacheResolver.GetCachedDomains()
registerMgmtCacheDomains := func() { if len(domains) > 0 {
domains := mgmtCacheResolver.GetCachedDomains() defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
if len(domains) > 0 {
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
}
} }
// Register any pre-populated domains from management cache
registerMgmtCacheDomains()
// Management cache resolver will be registered for specific domains when they are added
// register with root zone, handler chain takes care of the routing // register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain) dnsService.RegisterMux(".", handlerChain)
@@ -352,7 +340,6 @@ func (s *DefaultServer) Stop() {
} }
} }
s.service.Stop() s.service.Stop()
maps.Clear(s.extraDomains) maps.Clear(s.extraDomains)
@@ -368,15 +355,6 @@ func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
} }
// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration
func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error {
if s.mgmtCacheResolver == nil {
return fmt.Errorf("management cache resolver not initialized")
}
log.Debug("populating management cache from netbird configuration")
return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config)
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
@@ -476,6 +454,29 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Wait() wg.Wait()
} }
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
s.mux.Lock()
defer s.mux.Unlock()
if s.mgmtCacheResolver != nil {
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
if err != nil {
return fmt.Errorf("update management cache resolver: %w", err)
}
if len(removedDomains) > 0 {
s.DeregisterHandler(removedDomains, PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
if len(newDomains) > 0 {
s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache)
}
}
return nil
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver // is the service should be Disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records // and proceed with a regular update to clean up the handlers and records

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
@@ -363,7 +364,16 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil) dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -473,7 +483,16 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil) dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@@ -575,7 +594,16 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil) dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }

View File

@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -26,10 +27,10 @@ import (
) )
const ( const (
UpstreamTimeout = 15 * time.Second UpstreamTimeout = 4 * time.Second
// ClientTimeout is the timeout for the dns.Client. // ClientTimeout is the timeout for the dns.Client.
// Set longer than UpstreamTimeout to ensure context timeout takes precedence // Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 30 * time.Second ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
@@ -105,52 +106,111 @@ func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID() requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID) logger := log.WithField("request_id", requestID)
var err error
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
if u.isContextDone(logger) {
return
}
if u.tryUpstreamServers(w, r, logger) {
return
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
}
func (u *upstreamResolverBase) isContextDone(logger *log.Entry) bool {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
logger.Tracef("%s has been stopped", u) logger.Tracef("%s has been stopped", u)
return return true
default: default:
return false
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
var rm *dns.Msg if u.queryUpstream(w, r, upstream, timeout, logger) {
var t time.Duration return true
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
}()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
} }
}
return false
}
if rm == nil || !rm.Response { func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream string, timeout time.Duration, logger *log.Entry) bool {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) var rm *dns.Msg
continue var t time.Duration
} var err error
u.successCount.Add(1) var startTime time.Time
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
}()
if err = w.WriteMsg(rm); err != nil { if err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
} return false
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
}
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return return
} }
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
@@ -355,3 +415,97 @@ func GenerateRequestID() string {
} }
return hex.EncodeToString(bytes) return hex.EncodeToString(bytes)
} }
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
switch {
case !isConnected:
statusInfo += " DISCONNECTED"
case !hasRecentHandshake:
statusInfo += " NO_RECENT_HANDSHAKE"
default:
statusInfo += " connected"
}
if !peerState.LastWireguardHandshake.IsZero() {
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
} else {
statusInfo += " no_handshake"
}
if peerState.Relayed {
statusInfo += " via_relay"
}
if peerState.Latency > 0 {
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
}
return statusInfo
}
// findPeerForIP finds which peer handles the given IP address
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
if statusRecorder == nil {
return nil
}
fullStatus := statusRecorder.GetFullStatus()
var bestMatch *peer.State
var bestPrefixLen int
for _, peerState := range fullStatus.Peers {
routes := peerState.GetRoutes()
for route := range routes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
continue
}
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
peerStateCopy := peerState
bestMatch = &peerStateCopy
bestPrefixLen = prefix.Bits()
}
}
}
return bestMatch
}
// parseUpstreamIP parses an upstream server address to extract the IP
func parseUpstreamIP(upstream string) (netip.Addr, error) {
upstreamIP, err := netip.ParseAddr(upstream)
if err != nil {
if host, _, err := net.SplitHostPort(upstream); err == nil {
return netip.ParseAddr(host)
}
return netip.Addr{}, err
}
return upstreamIP, nil
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream string) string {
if u.statusRecorder == nil {
return ""
}
upstreamIP, err := parseUpstreamIP(upstream)
if err != nil {
return ""
}
peerInfo := findPeerForIP(upstreamIP, u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}

View File

@@ -33,6 +33,7 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -696,6 +697,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err) return fmt.Errorf("handle the flow configuration: %w", err)
} }
if e.dnsServer != nil {
serverDomains := config.ExtractFromNetbirdConfig(wCfg)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
}
// todo update signal // todo update signal
} }
@@ -1604,7 +1612,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
return dnsServer, nil return dnsServer, nil
default: default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig) // Extract domains from NetBird configuration
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig)
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
Ctx: e.ctx,
WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,
DisableSys: e.config.DisableDNS,
MgmtURL: mgmtURL,
ServerDomains: serverDomains,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -2,11 +2,13 @@ package dnsinterceptor
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -26,6 +28,8 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const dnsTimeout = 8 * time.Second
type domainMap map[domain.Domain][]netip.Prefix type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface { type internalDNATer interface {
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil { if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return return
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil { if err != nil {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
} }
return return
} }
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""
}
peerState, err := d.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
}
return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
}