From 629757c91135db1ebf2357aef03bb3443d674066 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 9 Jul 2025 19:00:17 +0200 Subject: [PATCH] Add debug output for timeouts --- client/internal/dns/config/domains.go | 155 ++++++++++ client/internal/dns/handler_chain.go | 5 +- client/internal/dns/mgmt/mgmt.go | 290 ++++++++++++------ client/internal/dns/mgmt/mgmt_test.go | 66 ++-- client/internal/dns/mock_server.go | 19 +- client/internal/dns/server.go | 103 ++++--- client/internal/dns/server_test.go | 34 +- client/internal/dns/upstream.go | 212 +++++++++++-- client/internal/engine.go | 22 +- .../routemanager/dnsinterceptor/handler.go | 34 +- 10 files changed, 736 insertions(+), 204 deletions(-) create mode 100644 client/internal/dns/config/domains.go diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go new file mode 100644 index 000000000..067608b2e --- /dev/null +++ b/client/internal/dns/config/domains.go @@ -0,0 +1,155 @@ +package config + +import ( + "fmt" + "net" + "net/netip" + "net/url" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/domain" + mgmProto "github.com/netbirdio/netbird/management/proto" +) + +// ServerDomains represents the management server domains extracted from NetBird configuration +type ServerDomains struct { + Signal domain.Domain + Relay []domain.Domain + Flow domain.Domain + Stuns []domain.Domain + Turns []domain.Domain +} + +// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration +func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { + if config == nil { + return ServerDomains{} + } + + domains := ServerDomains{} + + domains.Signal = extractSignalDomain(config) + domains.Relay = extractRelayDomains(config) + domains.Flow = extractFlowDomain(config) + domains.Stuns = extractStunDomains(config) + domains.Turns = extractTurnDomains(config) + + return domains +} + +// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses +func extractValidDomain(rawURL string) (domain.Domain, error) { + parsedURL, err := url.Parse(rawURL) + if err != nil { + // If URL parsing fails, it might be a raw host:port, try parsing as such + if host, _, err := net.SplitHostPort(rawURL); err == nil { + return extractDomainFromHost(host) + } + // If not host:port, try as raw hostname + return extractDomainFromHost(rawURL) + } + + host := parsedURL.Hostname() + if host == "" { + return "", fmt.Errorf("no hostname in URL") + } + + return extractDomainFromHost(host) +} + +// extractDomainFromHost extracts domain from a host string, filtering out IP addresses +func extractDomainFromHost(host string) (domain.Domain, error) { + if host == "" { + return "", fmt.Errorf("empty host") + } + + if _, err := netip.ParseAddr(host); err == nil { + return "", fmt.Errorf("IP address not allowed: %s", host) + } + + d, err := domain.FromString(host) + if err != nil { + return "", fmt.Errorf("invalid domain: %v", err) + } + + return d, nil +} + +// extractSingleDomain extracts a single domain from a URL with error logging +func extractSingleDomain(url, serviceType string) domain.Domain { + if url == "" { + return "" + } + + d, err := extractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + return "" + } + + return d +} + +// extractMultipleDomains extracts multiple domains from URLs with error logging +func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { + var domains []domain.Domain + for _, url := range urls { + if url == "" { + continue + } + d, err := extractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + continue + } + domains = append(domains, d) + } + return domains +} + +// extractSignalDomain extracts the signal domain from NetBird configuration. +func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Signal != nil { + return extractSingleDomain(config.Signal.Uri, "signal") + } + return "" +} + +// extractRelayDomains extracts relay server domains from NetBird configuration. +func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Relay != nil { + return extractMultipleDomains(config.Relay.Urls, "relay") + } + return nil +} + +// extractFlowDomain extracts the traffic flow domain from NetBird configuration. +func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Flow != nil { + return extractSingleDomain(config.Flow.Url, "flow") + } + return "" +} + +// extractStunDomains extracts STUN server domains from NetBird configuration. +func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, stun := range config.Stuns { + if stun != nil && stun.Uri != "" { + urls = append(urls, stun.Uri) + } + } + return extractMultipleDomains(urls, "STUN") +} + +// extractTurnDomains extracts TURN server domains from NetBird configuration. +func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, turn := range config.Turns { + if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { + urls = append(urls, turn.HostConfig.Uri) + } + } + return extractMultipleDomains(urls, "TURN") +} diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index ee80113e1..e73325b4b 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -182,7 +182,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants to continue, try next handler if chainWriter.shouldContinue { - log.Tracef("handler requested continue to next handler for domain=%s", qname) + // Only log continue for non-management cache handlers to reduce noise + if entry.Priority != PriorityMgmtCache { + log.Tracef("handler requested continue to next handler for domain=%s", qname) + } continue } return diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index ba8b1640b..993a51297 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -3,6 +3,7 @@ package mgmt import ( "context" "errors" + "fmt" "net" "net/netip" "net/url" @@ -13,6 +14,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -27,14 +29,12 @@ type CacheEntry struct { type Resolver struct { cache map[domain.Domain]CacheEntry mutex sync.RWMutex - systemResolver *net.Resolver } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ cache: make(map[domain.Domain]CacheEntry), - systemResolver: net.DefaultResolver, } } @@ -58,22 +58,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype]) - m.mutex.RLock() - parsedDomain, err := domain.FromString(qname) - if err != nil { - log.Tracef("MgmtCache: invalid domain format: %s", qname) - m.mutex.RUnlock() - m.continueToNext(w, r) - return - } - - entry, found := m.cache[parsedDomain] + domainKey := domain.Domain(qname) + entry, found := m.cache[domainKey] m.mutex.RUnlock() if !found { - log.Tracef("MgmtCache: no cache entry found for domain=%s", qname) m.continueToNext(w, r) return } @@ -91,7 +81,6 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } if len(records) == 0 { - log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString()) m.continueToNext(w, r) return } @@ -102,10 +91,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { resp.Answer = append(resp.Answer, rrCopy) } - log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString()) + log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString()) if err := w.WriteMsg(resp); err != nil { - log.Errorf("MgmtCache: failed to write response: %v", err) + log.Errorf("failed to write response: %v", err) } } @@ -120,60 +109,59 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { resp.SetRcode(r, dns.RcodeNameError) resp.MsgHdr.Zero = true if err := w.WriteMsg(resp); err != nil { - log.Errorf("MgmtCache: failed to write continue signal: %v", err) + log.Errorf("failed to write continue signal: %v", err) } } // AddDomain manually adds a domain to cache by resolving it. func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { - log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString()) + log.Debugf("adding domain=%s to cache", d.SafeString()) ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - var aRecords, aaaaRecords []dns.RR - - if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil { - for _, ip := range ips { - if ip.Is4() { - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: d.PunycodeString() + ".", - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: ip.AsSlice(), - } - aRecords = append(aRecords, rr) - } else if ip.Is6() { - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: d.PunycodeString() + ".", - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: ip.AsSlice(), - } - aaaaRecords = append(aaaaRecords, rr) - } - } - - m.mutex.Lock() - m.cache[d] = CacheEntry{ - ARecords: aRecords, - AAAARecords: aaaaRecords, - } - m.mutex.Unlock() - - log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records", - d.SafeString(), len(aRecords), len(aaaaRecords)) - } else { - log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err) - return err + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + if err != nil { + return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) } + var aRecords, aaaaRecords []dns.RR + for _, ip := range ips { + if ip.Is4() { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: d.PunycodeString() + ".", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: ip.AsSlice(), + } + aRecords = append(aRecords, rr) + } else if ip.Is6() { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: d.PunycodeString() + ".", + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: ip.AsSlice(), + } + aaaaRecords = append(aaaaRecords, rr) + } + } + + m.mutex.Lock() + m.cache[d] = CacheEntry{ + ARecords: aRecords, + AAAARecords: aaaaRecords, + } + m.mutex.Unlock() + + log.Debugf("added domain=%s with %d A records and %d AAAA records", + d.SafeString(), len(aRecords), len(aaaaRecords)) + return nil } @@ -182,7 +170,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err if mgmtURL != nil { if d, err := extractDomainFromURL(mgmtURL); err == nil { if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("MgmtCache: failed to add management domain: %v", err) + log.Warnf("failed to add management domain: %v", err) } } } @@ -190,6 +178,16 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err return nil } +// RemoveDomain removes a domain from the cache. +func (m *Resolver) RemoveDomain(d domain.Domain) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.cache, d) + log.Debugf("removed domain=%s from cache", d.SafeString()) + return nil +} + // PopulateFromNetbirdConfig extracts and caches domains from the netbird config. func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error { if config == nil { @@ -216,19 +214,19 @@ func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostCon // If parsing fails, it might be a raw host:port, try adding a scheme signalURL, err = url.Parse("https://" + signal.Uri) if err != nil { - log.Warnf("MgmtCache: failed to parse signal URL: %v", err) + log.Warnf("failed to parse signal URL: %v", err) return } } d, err := extractDomainFromURL(signalURL) if err != nil { - log.Warnf("MgmtCache: failed to extract signal domain: %v", err) + log.Warnf("failed to extract signal domain: %v", err) return } if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("MgmtCache: failed to add signal domain: %v", err) + log.Warnf("failed to add signal domain: %v", err) } } @@ -241,18 +239,18 @@ func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayCon for _, relayAddr := range relay.Urls { relayURL, err := url.Parse(relayAddr) if err != nil { - log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err) + log.Warnf("failed to parse relay URL %s: %v", relayAddr, err) continue } d, err := extractDomainFromURL(relayURL) if err != nil { - log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err) + log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err) continue } if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("MgmtCache: failed to add relay domain: %v", err) + log.Warnf("failed to add relay domain: %v", err) } } } @@ -265,18 +263,18 @@ func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig) flowURL, err := url.Parse(flow.Url) if err != nil { - log.Warnf("MgmtCache: failed to parse flow URL: %v", err) + log.Warnf("failed to parse flow URL: %v", err) return } d, err := extractDomainFromURL(flowURL) if err != nil { - log.Warnf("MgmtCache: failed to extract flow domain: %v", err) + log.Warnf("failed to extract flow domain: %v", err) return } if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("MgmtCache: failed to add flow domain: %v", err) + log.Warnf("failed to add flow domain: %v", err) } } @@ -303,7 +301,7 @@ func (m *Resolver) ClearCache() []domain.Domain { } m.cache = make(map[domain.Domain]CacheEntry) - log.Debugf("MgmtCache: cleared %d cached domains", len(domains)) + log.Debugf("cleared %d cached domains", len(domains)) return domains } @@ -311,7 +309,7 @@ func (m *Resolver) ClearCache() []domain.Domain { // UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations. // Returns domains that were removed for external deregistration. func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) { - log.Debugf("MgmtCache: updating cache from NetbirdConfig") + log.Debugf("updating cache from NetbirdConfig") currentDomains := m.GetCachedDomains() newDomains := m.extractDomainsFromConfig(config) @@ -333,19 +331,86 @@ func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto m.mutex.Lock() for _, domainToRemove := range removedDomains { delete(m.cache, domainToRemove) - log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString()) + log.Debugf("removed domain=%s from cache", domainToRemove.SafeString()) } m.mutex.Unlock() for _, newDomain := range newDomains { if err := m.AddDomain(ctx, newDomain); err != nil { - log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err) + log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) } } return removedDomains, nil } +// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct +func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) { + log.Debugf("updating cache from ServerDomains") + + currentDomains := m.GetCachedDomains() + newDomains := m.extractDomainsFromServerDomains(serverDomains) + + var removedDomains []domain.Domain + for _, currentDomain := range currentDomains { + found := false + for _, newDomain := range newDomains { + if currentDomain.SafeString() == newDomain.SafeString() { + found = true + break + } + } + if !found { + removedDomains = append(removedDomains, currentDomain) + if err := m.RemoveDomain(currentDomain); err != nil { + log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) + } + } + } + + for _, newDomain := range newDomains { + if err := m.AddDomain(ctx, newDomain); err != nil { + log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) + } else { + log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) + } + } + + return removedDomains, nil +} + +func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain { + var domains []domain.Domain + + if serverDomains.Signal != "" { + domains = append(domains, serverDomains.Signal) + } + + for _, relay := range serverDomains.Relay { + if relay != "" { + domains = append(domains, relay) + } + } + + if serverDomains.Flow != "" { + domains = append(domains, serverDomains.Flow) + } + + for _, stun := range serverDomains.Stuns { + if stun != "" { + domains = append(domains, stun) + } + } + + for _, turn := range serverDomains.Turns { + if turn != "" { + domains = append(domains, turn) + } + } + + return domains +} + // extractDomainsFromConfig extracts all domains from a NetbirdConfig. func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain { if config == nil { @@ -354,26 +419,62 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do var domains []domain.Domain - if config.Signal != nil && config.Signal.Uri != "" { - if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil { + // Extract signal domain + domains = append(domains, m.extractSignalDomain(config)...) + + // Extract relay domains + domains = append(domains, m.extractRelayDomains(config)...) + + // Extract flow domain + domains = append(domains, m.extractFlowDomain(config)...) + + // Extract STUN domains + domains = append(domains, m.extractSTUNDomains(config)...) + + // Extract TURN domains + domains = append(domains, m.extractTURNDomains(config)...) + + return domains +} + +func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Signal == nil || config.Signal.Uri == "" { + return nil + } + + if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil { + return []domain.Domain{d} + } + return nil +} + +func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Relay == nil { + return nil + } + + var domains []domain.Domain + for _, relayURL := range config.Relay.Urls { + if d, err := m.extractDomainFromURL(relayURL); err == nil { domains = append(domains, d) } } + return domains +} - if config.Relay != nil { - for _, relayURL := range config.Relay.Urls { - if d, err := m.extractDomainFromURL(relayURL); err == nil { - domains = append(domains, d) - } - } +func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Flow == nil || config.Flow.Url == "" { + return nil } - if config.Flow != nil && config.Flow.Url != "" { - if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil { - domains = append(domains, d) - } + if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil { + return []domain.Domain{d} } + return nil +} +func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var domains []domain.Domain for _, stun := range config.Stuns { if stun != nil && stun.Uri != "" { if d, err := m.extractDomainFromURL(stun.Uri); err == nil { @@ -381,7 +482,11 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do } } } + return domains +} +func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var domains []domain.Domain for _, turn := range config.Turns { if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil { @@ -389,7 +494,6 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do } } } - return domains } @@ -424,18 +528,18 @@ func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostCon stunURL, err := url.Parse(stun.Uri) if err != nil { - log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err) + log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err) continue } d, err := extractDomainFromURL(stunURL) if err != nil { - log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err) + log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err) continue } if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("MgmtCache: failed to add STUN domain: %v", err) + log.Warnf("failed to add STUN domain: %v", err) } } } @@ -449,18 +553,18 @@ func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.Protect turnURL, err := url.Parse(turn.HostConfig.Uri) if err != nil { - log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err) + log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err) continue } d, err := extractDomainFromURL(turnURL) if err != nil { - log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err) + log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err) continue } if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("MgmtCache: failed to add TURN domain: %v", err) + log.Warnf("failed to add TURN domain: %v", err) } } } diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index bd4aec99b..17da1a75c 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -5,7 +5,6 @@ import ( "net" "net/url" "testing" - "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -114,16 +113,15 @@ func TestResolver_PopulateFromConfig(t *testing.T) { resolver := NewResolver() - mgmtURL, _ := url.Parse("https://api.netbird.io") + // Use IP address to avoid DNS resolution timeout + mgmtURL, _ := url.Parse("https://127.0.0.1") err := resolver.PopulateFromConfig(ctx, mgmtURL) assert.NoError(t, err) - // Give some time for async population - time.Sleep(100 * time.Millisecond) - + // IP addresses are rejected, so no domains should be cached domains := resolver.GetCachedDomains() - assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature + assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") } func TestResolver_PopulateFromNetbirdConfig(t *testing.T) { @@ -132,32 +130,33 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) { resolver := NewResolver() + // Use IP addresses to avoid DNS resolution timeouts netbirdConfig := &mgmProto.NetbirdConfig{ Signal: &mgmProto.HostConfig{ - Uri: "https://signal.netbird.io", + Uri: "https://10.0.0.1", }, Relay: &mgmProto.RelayConfig{ Urls: []string{ - "https://relay1.netbird.io:443", - "https://relay2.netbird.io:443", + "https://10.0.0.2:443", + "https://10.0.0.3:443", }, }, Flow: &mgmProto.FlowConfig{ - Url: "https://flow.netbird.io:80", + Url: "https://10.0.0.4:80", }, Stuns: []*mgmProto.HostConfig{ - {Uri: "stun:stun1.netbird.io:3478"}, - {Uri: "stun:stun2.netbird.io:3478"}, + {Uri: "stun:10.0.0.5:3478"}, + {Uri: "stun:10.0.0.6:3478"}, }, Turns: []*mgmProto.ProtectedHostConfig{ { HostConfig: &mgmProto.HostConfig{ - Uri: "turn:turn1.netbird.io:3478", + Uri: "turn:10.0.0.7:3478", }, }, { HostConfig: &mgmProto.HostConfig{ - Uri: "turn:turn2.netbird.io:3478", + Uri: "turn:10.0.0.8:3478", }, }, }, @@ -166,11 +165,42 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) { err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig) assert.NoError(t, err) - // Give some time for async population - time.Sleep(100 * time.Millisecond) - + // IP addresses are rejected, so no domains should be cached domains := resolver.GetCachedDomains() - assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature + assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") +} + +func TestResolver_UpdateFromNetbirdConfig(t *testing.T) { + resolver := NewResolver() + + // Test with empty initial config and then add domains + initialConfig := &mgmProto.NetbirdConfig{} + + // Start with empty config + removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig) + assert.NoError(t, err) + assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache") + + // Update to config with IP addresses instead of domains to avoid DNS resolution + // IP addresses will be rejected by extractDomainFromURL so no actual resolution happens + updatedConfig := &mgmProto.NetbirdConfig{ + Signal: &mgmProto.HostConfig{ + Uri: "https://127.0.0.1", + }, + Flow: &mgmProto.FlowConfig{ + Url: "https://192.168.1.1:80", + }, + } + + removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig) + assert.NoError(t, err) + + // Verify the method completes successfully without DNS timeouts + assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update") + + // Verify no domains were actually added since IPs are rejected + domains := resolver.GetCachedDomains() + assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") } func TestResolver_ContinueToNext(t *testing.T) { diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index c5dd6e23f..cd1cba1e6 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -5,17 +5,19 @@ import ( "github.com/miekg/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func(domain.List, dns.Handler, int) - DeregisterHandlerFunc func(domain.List, int) + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func(domain.List, dns.Handler, int) + DeregisterHandlerFunc func(domain.List, int) + UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error } func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -69,3 +71,10 @@ func (m *MockServer) SearchDomains() []string { // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface func (m *MockServer) ProbeAvailability() { } + +func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + if m.UpdateServerConfigFunc != nil { + return m.UpdateServerConfigFunc(domains) + } + return nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 1fa76c8ef..8d7ebfcac 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -16,6 +16,7 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/types" @@ -25,7 +26,6 @@ import ( cProto "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" - mgmProto "github.com/netbirdio/netbird/management/proto" ) // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes @@ -49,6 +49,7 @@ type Server interface { OnUpdatedHostDNSServer(strings []string) SearchDomains() []string ProbeAvailability() + UpdateServerConfig(domains dnsconfig.ServerDomains) error } type nsGroupsByDomain struct { @@ -103,20 +104,23 @@ type handlerWrapper struct { type registeredHandlerMap map[types.HandlerID]handlerWrapper +// DefaultServerConfig holds configuration parameters for NewDefaultServer +type DefaultServerConfig struct { + Ctx context.Context + WgInterface WGIface + CustomAddress string + StatusRecorder *peer.Status + StateManager *statemanager.Manager + DisableSys bool + MgmtURL *url.URL + ServerDomains dnsconfig.ServerDomains +} + // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, - stateManager *statemanager.Manager, - disableSys bool, - mgmtURL *url.URL, - netbirdConfig *mgmProto.NetbirdConfig, -) (*DefaultServer, error) { +func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) { var addrPort *netip.AddrPort - if customAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(customAddress) + if config.CustomAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) if err != nil { return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) } @@ -124,31 +128,23 @@ func NewDefaultServer( } var dnsService service - if wgInterface.IsUserspaceBind() { - dnsService = NewServiceViaMemory(wgInterface) + if config.WgInterface.IsUserspaceBind() { + dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(wgInterface, addrPort) + dnsService = newServiceViaListener(config.WgInterface, addrPort) } - server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys) + server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) - // Pre-populate management cache with management URL - if mgmtURL != nil && server.mgmtCacheResolver != nil { - if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil { + if config.MgmtURL != nil && server.mgmtCacheResolver != nil { + if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil { log.Warnf("Failed to populate management cache from management URL: %v", err) } } - // Pre-populate management cache with NetbirdConfig domains - if netbirdConfig != nil && server.mgmtCacheResolver != nil { - if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil { - log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err) - } - - // Register newly populated domains - domains := server.mgmtCacheResolver.GetCachedDomains() - if len(domains) > 0 { - server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache) + if server.mgmtCacheResolver != nil { + if err := server.UpdateServerConfig(config.ServerDomains); err != nil { + log.Warnf("Failed to populate management cache from ServerDomains: %v", err) } } @@ -220,19 +216,11 @@ func newDefaultServer( mgmtCacheResolver: mgmtCacheResolver, } - // Register cached domains with the handler chain - registerMgmtCacheDomains := func() { - domains := mgmtCacheResolver.GetCachedDomains() - if len(domains) > 0 { - defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache) - } + domains := mgmtCacheResolver.GetCachedDomains() + if len(domains) > 0 { + defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache) } - // Register any pre-populated domains from management cache - registerMgmtCacheDomains() - - // Management cache resolver will be registered for specific domains when they are added - // register with root zone, handler chain takes care of the routing dnsService.RegisterMux(".", handlerChain) @@ -352,7 +340,6 @@ func (s *DefaultServer) Stop() { } } - s.service.Stop() maps.Clear(s.extraDomains) @@ -368,15 +355,6 @@ func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error { return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) } -// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration -func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error { - if s.mgmtCacheResolver == nil { - return fmt.Errorf("management cache resolver not initialized") - } - - log.Debug("populating management cache from netbird configuration") - return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config) -} // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone @@ -476,6 +454,29 @@ func (s *DefaultServer) ProbeAvailability() { wg.Wait() } +func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + s.mux.Lock() + defer s.mux.Unlock() + + if s.mgmtCacheResolver != nil { + removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) + if err != nil { + return fmt.Errorf("update management cache resolver: %w", err) + } + + if len(removedDomains) > 0 { + s.DeregisterHandler(removedDomains, PriorityMgmtCache) + } + + newDomains := s.mgmtCacheResolver.GetCachedDomains() + if len(newDomains) > 0 { + s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache) + } + } + + return nil +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 4a806be3f..4892fb3af 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/types" @@ -363,7 +364,16 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil) + dnsServer, err := NewDefaultServer(DefaultServerConfig{ + Ctx: context.Background(), + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + MgmtURL: nil, + ServerDomains: dnsconfig.ServerDomains{}, + }) if err != nil { t.Fatal(err) } @@ -473,7 +483,16 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil) + dnsServer, err := NewDefaultServer(DefaultServerConfig{ + Ctx: context.Background(), + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + MgmtURL: nil, + ServerDomains: dnsconfig.ServerDomains{}, + }) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -575,7 +594,16 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil) + dnsServer, err := NewDefaultServer(DefaultServerConfig{ + Ctx: context.Background(), + WgInterface: &mocWGIface{}, + CustomAddress: testCase.addrPort, + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + MgmtURL: nil, + ServerDomains: dnsconfig.ServerDomains{}, + }) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index ddc9f81a1..863414537 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net" + "net/netip" "slices" "strings" "sync" @@ -26,10 +27,10 @@ import ( ) const ( - UpstreamTimeout = 15 * time.Second + UpstreamTimeout = 4 * time.Second // ClientTimeout is the timeout for the dns.Client. // Set longer than UpstreamTimeout to ensure context timeout takes precedence - ClientTimeout = 30 * time.Second + ClientTimeout = 5 * time.Second reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second @@ -105,52 +106,111 @@ func (u *upstreamResolverBase) Stop() { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { requestID := GenerateRequestID() logger := log.WithField("request_id", requestID) - var err error logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + u.prepareRequest(r) + + if u.isContextDone(logger) { + return + } + + if u.tryUpstreamServers(w, r, logger) { + return + } + + u.writeErrorResponse(w, r, logger) +} + +func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } +} +func (u *upstreamResolverBase) isContextDone(logger *log.Entry) bool { select { case <-u.ctx.Done(): logger.Tracef("%s has been stopped", u) - return + return true default: + return false + } +} + +func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { + timeout := u.upstreamTimeout + if len(u.upstreamServers) > 1 { + maxTotal := 5 * time.Second + minPerUpstream := 2 * time.Second + scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) + if scaledTimeout > minPerUpstream { + timeout = scaledTimeout + } else { + timeout = minPerUpstream + } } for _, upstream := range u.upstreamServers { - var rm *dns.Msg - var t time.Duration - - func() { - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - defer cancel() - rm, t, err = u.upstreamClient.exchange(ctx, upstream, r) - }() - - if err != nil { - if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) - continue - } - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) - continue + if u.queryUpstream(w, r, upstream, timeout, logger) { + return true } + } + return false +} - if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - continue - } +func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream string, timeout time.Duration, logger *log.Entry) bool { + var rm *dns.Msg + var t time.Duration + var err error - u.successCount.Add(1) - logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + var startTime time.Time + func() { + ctx, cancel := context.WithTimeout(u.ctx, timeout) + defer cancel() + startTime = time.Now() + rm, t, err = u.upstreamClient.exchange(ctx, upstream, r) + }() - if err = w.WriteMsg(rm); err != nil { - logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) - } + if err != nil { + u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) + return false + } + + if rm == nil || !rm.Response { + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + return false + } + + return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) +} + +func (u *upstreamResolverBase) handleUpstreamError(err error, upstream, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { + if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) return } + + elapsed := time.Since(startTime) + timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) + if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { + timeoutMsg += " " + peerInfo + } + timeoutMsg += fmt.Sprintf(" - error: %v", err) + logger.Warnf(timeoutMsg) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream, domain string, t time.Duration, logger *log.Entry) bool { + u.successCount.Add(1) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain) + + if err := w.WriteMsg(rm); err != nil { + logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) + } + return true +} + +func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) @@ -355,3 +415,97 @@ func GenerateRequestID() string { } return hex.EncodeToString(bytes) } + +// FormatPeerStatus formats peer connection status information for debugging DNS timeouts +func FormatPeerStatus(peerState *peer.State) string { + isConnected := peerState.ConnStatus == peer.StatusConnected + hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() && + time.Since(peerState.LastWireguardHandshake) < 3*time.Minute + + statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP) + + switch { + case !isConnected: + statusInfo += " DISCONNECTED" + case !hasRecentHandshake: + statusInfo += " NO_RECENT_HANDSHAKE" + default: + statusInfo += " connected" + } + + if !peerState.LastWireguardHandshake.IsZero() { + timeSinceHandshake := time.Since(peerState.LastWireguardHandshake) + statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second)) + } else { + statusInfo += " no_handshake" + } + + if peerState.Relayed { + statusInfo += " via_relay" + } + + if peerState.Latency > 0 { + statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency) + } + + return statusInfo +} + +// findPeerForIP finds which peer handles the given IP address +func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { + if statusRecorder == nil { + return nil + } + + fullStatus := statusRecorder.GetFullStatus() + var bestMatch *peer.State + var bestPrefixLen int + + for _, peerState := range fullStatus.Peers { + routes := peerState.GetRoutes() + for route := range routes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + continue + } + + if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen { + peerStateCopy := peerState + bestMatch = &peerStateCopy + bestPrefixLen = prefix.Bits() + } + } + } + + return bestMatch +} + +// parseUpstreamIP parses an upstream server address to extract the IP +func parseUpstreamIP(upstream string) (netip.Addr, error) { + upstreamIP, err := netip.ParseAddr(upstream) + if err != nil { + if host, _, err := net.SplitHostPort(upstream); err == nil { + return netip.ParseAddr(host) + } + return netip.Addr{}, err + } + return upstreamIP, nil +} + +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream string) string { + if u.statusRecorder == nil { + return "" + } + + upstreamIP, err := parseUpstreamIP(upstream) + if err != nil { + return "" + } + + peerInfo := findPeerForIP(upstreamIP, u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 9104f042f..f4aed0b3f 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -33,6 +33,7 @@ import ( nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" @@ -696,6 +697,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("handle the flow configuration: %w", err) } + if e.dnsServer != nil { + serverDomains := config.ExtractFromNetbirdConfig(wCfg) + if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { + log.Warnf("Failed to update DNS server config: %v", err) + } + } + // todo update signal } @@ -1604,7 +1612,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird return dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig) + // Extract domains from NetBird configuration + serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig) + + dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{ + Ctx: e.ctx, + WgInterface: e.wgInterface, + CustomAddress: e.config.CustomDNSAddress, + StatusRecorder: e.statusRecorder, + StateManager: e.stateManager, + DisableSys: e.config.DisableDNS, + MgmtURL: mgmtURL, + ServerDomains: serverDomains, + }) if err != nil { return nil, err } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index c7c3aeb0b..b0413b07d 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -2,11 +2,13 @@ package dnsinterceptor import ( "context" + "errors" "fmt" "net/netip" "runtime" "strings" "sync" + "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -26,6 +28,8 @@ import ( "github.com/netbirdio/netbird/route" ) +const dnsTimeout = 8 * time.Second + type domainMap map[domain.Domain][]netip.Prefix type internalDNATer interface { @@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) + client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) if err != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return @@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + startTime := time.Now() + reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) if err != nil { - logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + if errors.Is(err, context.DeadlineExceeded) { + elapsed := time.Since(startTime) + peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) + logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", + elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err) + } else { + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + } if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { logger.Errorf("failed writing DNS response: %v", err) } @@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR } return } + +func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { + if d.statusRecorder == nil { + return "" + } + + peerState, err := d.statusRecorder.GetPeer(peerKey) + if err != nil { + return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err) + } + + return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState)) +}