Add debug output for timeouts

This commit is contained in:
Viktor Liu
2025-07-09 19:00:17 +02:00
parent a1cb7b4af6
commit 629757c911
10 changed files with 736 additions and 204 deletions

View File

@@ -0,0 +1,155 @@
package config
import (
"fmt"
"net"
"net/netip"
"net/url"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// ServerDomains represents the management server domains extracted from NetBird configuration.
// Every field is filled by ExtractFromNetbirdConfig; a URL whose host is an IP
// address, or from which no domain can be extracted, contributes nothing.
type ServerDomains struct {
// Signal is the domain taken from the signal service URI; empty when none was extracted.
Signal domain.Domain
// Relay holds the domains taken from the relay URLs.
Relay []domain.Domain
// Flow is the domain taken from the traffic flow URL; empty when none was extracted.
Flow domain.Domain
// Stuns holds the domains taken from the STUN server URIs.
Stuns []domain.Domain
// Turns holds the domains taken from the TURN server URIs.
Turns []domain.Domain
}
// ExtractFromNetbirdConfig collects the signal, relay, flow, STUN and TURN
// domains referenced by a NetBird protobuf configuration.
// A nil config yields the zero ServerDomains.
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
	var result ServerDomains
	if config != nil {
		result = ServerDomains{
			Signal: extractSignalDomain(config),
			Relay:  extractRelayDomains(config),
			Flow:   extractFlowDomain(config),
			Stuns:  extractStunDomains(config),
			Turns:  extractTurnDomains(config),
		}
	}
	return result
}
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses.
//
// Besides hierarchical URLs ("https://host:port/path") it accepts the rootless
// forms that url.Parse reports as opaque data instead of filling in a host:
//   - "stun:host:port" or "turn:host:port?transport=udp"
//     (Scheme "stun"/"turn", Opaque "host:port")
//   - "host:port" with a domain host (Scheme "host", Opaque "port")
//
// Inputs that url.Parse rejects outright are retried as a raw host:port and
// then as a raw hostname. An error is returned when no host can be found or
// when the host is not an acceptable domain.
func extractValidDomain(rawURL string) (domain.Domain, error) {
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		// If URL parsing fails, it might be a raw host:port, try parsing as such
		if host, _, splitErr := net.SplitHostPort(rawURL); splitErr == nil {
			return extractDomainFromHost(host)
		}
		// If not host:port, try as raw hostname
		return extractDomainFromHost(rawURL)
	}

	if host := parsedURL.Hostname(); host != "" {
		return extractDomainFromHost(host)
	}

	if parsedURL.Opaque != "" {
		// scheme:host:port — the opaque part carries both host and port.
		if host, _, splitErr := net.SplitHostPort(parsedURL.Opaque); splitErr == nil && host != "" {
			return extractDomainFromHost(host)
		}
		// host:port — url.Parse took the host for a scheme and left only the
		// numeric port as opaque data, so split the original string instead.
		if isPortOnly(parsedURL.Opaque) {
			if host, _, splitErr := net.SplitHostPort(rawURL); splitErr == nil && host != "" {
				return extractDomainFromHost(host)
			}
		}
	}

	return "", fmt.Errorf("no hostname in URL")
}

// isPortOnly reports whether s is a non-empty run of ASCII digits.
func isPortOnly(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		if s[i] < '0' || s[i] > '9' {
			return false
		}
	}
	return true
}
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses.
//
// It returns an error when host is empty, when host parses as an IP address,
// or when domain.FromString rejects it. In the last case the underlying error
// is wrapped so callers can inspect it with errors.Is / errors.As.
func extractDomainFromHost(host string) (domain.Domain, error) {
	if host == "" {
		return "", fmt.Errorf("empty host")
	}

	// Anything netip can parse is an address literal, not a domain name.
	if _, err := netip.ParseAddr(host); err == nil {
		return "", fmt.Errorf("IP address not allowed: %s", host)
	}

	d, err := domain.FromString(host)
	if err != nil {
		return "", fmt.Errorf("invalid domain: %w", err)
	}

	return d, nil
}
// extractSingleDomain turns one service URL into its domain.
// An empty URL yields "" silently; a URL that cannot be converted is logged at
// debug level under serviceType and also yields "".
func extractSingleDomain(url, serviceType string) domain.Domain {
	if url != "" {
		extracted, err := extractValidDomain(url)
		if err == nil {
			return extracted
		}
		log.Debugf("Skipping %s: %v", serviceType, err)
	}
	return ""
}
// extractMultipleDomains turns a list of service URLs into their domains,
// keeping input order. Empty entries are ignored; entries that cannot be
// converted are logged at debug level under serviceType and left out.
// The result is nil when nothing was extracted.
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
	var collected []domain.Domain

	for i := range urls {
		rawURL := urls[i]
		if rawURL == "" {
			continue
		}

		if d, err := extractValidDomain(rawURL); err != nil {
			log.Debugf("Skipping %s: %v", serviceType, err)
		} else {
			collected = append(collected, d)
		}
	}

	return collected
}
// extractSignalDomain returns the domain of the signal service URI, or ""
// when the configuration has no signal section or no domain can be extracted.
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
	signal := config.Signal
	if signal == nil {
		return ""
	}
	return extractSingleDomain(signal.Uri, "signal")
}
// extractRelayDomains returns the domains of all relay URLs, or nil when the
// configuration has no relay section.
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
	relay := config.Relay
	if relay == nil {
		return nil
	}
	return extractMultipleDomains(relay.Urls, "relay")
}
// extractFlowDomain returns the domain of the traffic flow URL, or "" when the
// configuration has no flow section or no domain can be extracted.
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
	flow := config.Flow
	if flow == nil {
		return ""
	}
	return extractSingleDomain(flow.Url, "flow")
}
// extractStunDomains returns the domains of the configured STUN servers.
// Nil entries and entries without a URI are ignored.
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
	var uris []string
	for _, host := range config.Stuns {
		if host == nil || host.Uri == "" {
			continue
		}
		uris = append(uris, host.Uri)
	}
	return extractMultipleDomains(uris, "STUN")
}
// extractTurnDomains returns the domains of the configured TURN servers.
// Nil entries, entries without a host config, and entries without a URI are
// ignored.
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
	var uris []string
	for _, protected := range config.Turns {
		if protected == nil || protected.HostConfig == nil {
			continue
		}
		if uri := protected.HostConfig.Uri; uri != "" {
			uris = append(uris, uri)
		}
	}
	return extractMultipleDomains(uris, "TURN")
}

View File

@@ -182,7 +182,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return

View File

@@ -3,6 +3,7 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
@@ -13,6 +14,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -27,14 +29,12 @@ type CacheEntry struct {
type Resolver struct {
cache map[domain.Domain]CacheEntry
mutex sync.RWMutex
systemResolver *net.Resolver
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
cache: make(map[domain.Domain]CacheEntry),
systemResolver: net.DefaultResolver,
}
}
@@ -58,22 +58,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype])
m.mutex.RLock()
parsedDomain, err := domain.FromString(qname)
if err != nil {
log.Tracef("MgmtCache: invalid domain format: %s", qname)
m.mutex.RUnlock()
m.continueToNext(w, r)
return
}
entry, found := m.cache[parsedDomain]
domainKey := domain.Domain(qname)
entry, found := m.cache[domainKey]
m.mutex.RUnlock()
if !found {
log.Tracef("MgmtCache: no cache entry found for domain=%s", qname)
m.continueToNext(w, r)
return
}
@@ -91,7 +81,6 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
if len(records) == 0 {
log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString())
m.continueToNext(w, r)
return
}
@@ -102,10 +91,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp.Answer = append(resp.Answer, rrCopy)
}
log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString())
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
if err := w.WriteMsg(resp); err != nil {
log.Errorf("MgmtCache: failed to write response: %v", err)
log.Errorf("failed to write response: %v", err)
}
}
@@ -120,60 +109,59 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("MgmtCache: failed to write continue signal: %v", err)
log.Errorf("failed to write continue signal: %v", err)
}
}
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString())
log.Debugf("adding domain=%s to cache", d.SafeString())
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
var aRecords, aaaaRecords []dns.RR
if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil {
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}
m.mutex.Lock()
m.cache[d] = CacheEntry{
ARecords: aRecords,
AAAARecords: aaaaRecords,
}
m.mutex.Unlock()
log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
} else {
log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err)
return err
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
}
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}
m.mutex.Lock()
m.cache[d] = CacheEntry{
ARecords: aRecords,
AAAARecords: aaaaRecords,
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
@@ -182,7 +170,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
if mgmtURL != nil {
if d, err := extractDomainFromURL(mgmtURL); err == nil {
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add management domain: %v", err)
log.Warnf("failed to add management domain: %v", err)
}
}
}
@@ -190,6 +178,16 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
return nil
}
// RemoveDomain deletes d from the cache under the write lock and logs the
// removal at debug level. It always returns nil.
func (m *Resolver) RemoveDomain(d domain.Domain) error {
	m.mutex.Lock()
	delete(m.cache, d)
	m.mutex.Unlock()

	log.Debugf("removed domain=%s from cache", d.SafeString())
	return nil
}
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
if config == nil {
@@ -216,19 +214,19 @@ func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostCon
// If parsing fails, it might be a raw host:port, try adding a scheme
signalURL, err = url.Parse("https://" + signal.Uri)
if err != nil {
log.Warnf("MgmtCache: failed to parse signal URL: %v", err)
log.Warnf("failed to parse signal URL: %v", err)
return
}
}
d, err := extractDomainFromURL(signalURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract signal domain: %v", err)
log.Warnf("failed to extract signal domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add signal domain: %v", err)
log.Warnf("failed to add signal domain: %v", err)
}
}
@@ -241,18 +239,18 @@ func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayCon
for _, relayAddr := range relay.Urls {
relayURL, err := url.Parse(relayAddr)
if err != nil {
log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err)
log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
continue
}
d, err := extractDomainFromURL(relayURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err)
log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add relay domain: %v", err)
log.Warnf("failed to add relay domain: %v", err)
}
}
}
@@ -265,18 +263,18 @@ func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig)
flowURL, err := url.Parse(flow.Url)
if err != nil {
log.Warnf("MgmtCache: failed to parse flow URL: %v", err)
log.Warnf("failed to parse flow URL: %v", err)
return
}
d, err := extractDomainFromURL(flowURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract flow domain: %v", err)
log.Warnf("failed to extract flow domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add flow domain: %v", err)
log.Warnf("failed to add flow domain: %v", err)
}
}
@@ -303,7 +301,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
}
m.cache = make(map[domain.Domain]CacheEntry)
log.Debugf("MgmtCache: cleared %d cached domains", len(domains))
log.Debugf("cleared %d cached domains", len(domains))
return domains
}
@@ -311,7 +309,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
// Returns domains that were removed for external deregistration.
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
log.Debugf("MgmtCache: updating cache from NetbirdConfig")
log.Debugf("updating cache from NetbirdConfig")
currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromConfig(config)
@@ -333,19 +331,86 @@ func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto
m.mutex.Lock()
for _, domainToRemove := range removedDomains {
delete(m.cache, domainToRemove)
log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString())
log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
}
m.mutex.Unlock()
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err)
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
}
}
return removedDomains, nil
}
// UpdateFromServerDomains reconciles the cache with the domains listed in
// serverDomains.
//
// Cached domains that are no longer listed (compared by SafeString) are
// removed and returned so the caller can deregister them. Every listed domain
// is then passed to AddDomain; failures are logged and do not abort the
// update. The returned error is always nil.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) {
	log.Debugf("updating cache from ServerDomains")

	cached := m.GetCachedDomains()
	wanted := m.extractDomainsFromServerDomains(serverDomains)

	// Index the wanted domains once so each cached domain is a single lookup.
	keep := make(map[string]struct{}, len(wanted))
	for _, d := range wanted {
		keep[d.SafeString()] = struct{}{}
	}

	var removedDomains []domain.Domain
	for _, d := range cached {
		if _, listed := keep[d.SafeString()]; listed {
			continue
		}
		removedDomains = append(removedDomains, d)
		if err := m.RemoveDomain(d); err != nil {
			log.Warnf("failed to remove domain=%s: %v", d.SafeString(), err)
		}
	}

	for _, d := range wanted {
		err := m.AddDomain(ctx, d)
		if err != nil {
			log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
			continue
		}
		log.Debugf("added/updated management cache domain=%s", d.SafeString())
	}

	return removedDomains, nil
}
// extractDomainsFromServerDomains flattens serverDomains into one list in the
// order signal, relay, flow, STUN, TURN, dropping empty entries.
// The result is nil when every entry is empty.
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain {
	var flattened []domain.Domain

	keepNonEmpty := func(candidates ...domain.Domain) {
		for _, candidate := range candidates {
			if candidate != "" {
				flattened = append(flattened, candidate)
			}
		}
	}

	keepNonEmpty(serverDomains.Signal)
	keepNonEmpty(serverDomains.Relay...)
	keepNonEmpty(serverDomains.Flow)
	keepNonEmpty(serverDomains.Stuns...)
	keepNonEmpty(serverDomains.Turns...)

	return flattened
}
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
if config == nil {
@@ -354,26 +419,62 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
var domains []domain.Domain
if config.Signal != nil && config.Signal.Uri != "" {
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
// Extract signal domain
domains = append(domains, m.extractSignalDomain(config)...)
// Extract relay domains
domains = append(domains, m.extractRelayDomains(config)...)
// Extract flow domain
domains = append(domains, m.extractFlowDomain(config)...)
// Extract STUN domains
domains = append(domains, m.extractSTUNDomains(config)...)
// Extract TURN domains
domains = append(domains, m.extractTURNDomains(config)...)
return domains
}
// extractSignalDomain returns the signal domain as a one-element slice, or nil
// when the signal section is absent, has no URI, or extraction fails.
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
	if config.Signal != nil && config.Signal.Uri != "" {
		d, err := m.extractDomainFromSignalConfig(config.Signal)
		if err == nil {
			return []domain.Domain{d}
		}
	}
	return nil
}
// extractRelayDomains returns the domains of all relay URLs whose extraction
// succeeds, or nil when the relay section is absent or nothing was extracted.
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
	var relays []domain.Domain
	if config.Relay != nil {
		for _, relayURL := range config.Relay.Urls {
			d, err := m.extractDomainFromURL(relayURL)
			if err != nil {
				continue
			}
			relays = append(relays, d)
		}
	}
	return relays
}
if config.Relay != nil {
for _, relayURL := range config.Relay.Urls {
if d, err := m.extractDomainFromURL(relayURL); err == nil {
domains = append(domains, d)
}
}
func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Flow == nil || config.Flow.Url == "" {
return nil
}
if config.Flow != nil && config.Flow.Url != "" {
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
domains = append(domains, d)
}
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
@@ -381,7 +482,11 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
}
}
}
return domains
}
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
@@ -389,7 +494,6 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
}
}
}
return domains
}
@@ -424,18 +528,18 @@ func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostCon
stunURL, err := url.Parse(stun.Uri)
if err != nil {
log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err)
log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
continue
}
d, err := extractDomainFromURL(stunURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err)
log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add STUN domain: %v", err)
log.Warnf("failed to add STUN domain: %v", err)
}
}
}
@@ -449,18 +553,18 @@ func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.Protect
turnURL, err := url.Parse(turn.HostConfig.Uri)
if err != nil {
log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
continue
}
d, err := extractDomainFromURL(turnURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add TURN domain: %v", err)
log.Warnf("failed to add TURN domain: %v", err)
}
}
}

View File

@@ -5,7 +5,6 @@ import (
"net"
"net/url"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -114,16 +113,15 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
resolver := NewResolver()
mgmtURL, _ := url.Parse("https://api.netbird.io")
// Use IP address to avoid DNS resolution timeout
mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.NoError(t, err)
// Give some time for async population
time.Sleep(100 * time.Millisecond)
// IP addresses are rejected, so no domains should be cached
domains := resolver.GetCachedDomains()
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
@@ -132,32 +130,33 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
resolver := NewResolver()
// Use IP addresses to avoid DNS resolution timeouts
netbirdConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{
Uri: "https://signal.netbird.io",
Uri: "https://10.0.0.1",
},
Relay: &mgmProto.RelayConfig{
Urls: []string{
"https://relay1.netbird.io:443",
"https://relay2.netbird.io:443",
"https://10.0.0.2:443",
"https://10.0.0.3:443",
},
},
Flow: &mgmProto.FlowConfig{
Url: "https://flow.netbird.io:80",
Url: "https://10.0.0.4:80",
},
Stuns: []*mgmProto.HostConfig{
{Uri: "stun:stun1.netbird.io:3478"},
{Uri: "stun:stun2.netbird.io:3478"},
{Uri: "stun:10.0.0.5:3478"},
{Uri: "stun:10.0.0.6:3478"},
},
Turns: []*mgmProto.ProtectedHostConfig{
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:turn1.netbird.io:3478",
Uri: "turn:10.0.0.7:3478",
},
},
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:turn2.netbird.io:3478",
Uri: "turn:10.0.0.8:3478",
},
},
},
@@ -166,11 +165,42 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
assert.NoError(t, err)
// Give some time for async population
time.Sleep(100 * time.Millisecond)
// IP addresses are rejected, so no domains should be cached
domains := resolver.GetCachedDomains()
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
// TestResolver_UpdateFromNetbirdConfig checks that updating from a config
// whose URLs all use IP addresses neither removes nor caches any domain.
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) {
	resolver := NewResolver()

	// Test with empty initial config and then add domains
	initialConfig := &mgmProto.NetbirdConfig{}

	// Start with empty config
	removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig)
	assert.NoError(t, err)
	assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache")

	// Update to config with IP addresses instead of domains to avoid DNS resolution
	// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
	updatedConfig := &mgmProto.NetbirdConfig{
		Signal: &mgmProto.HostConfig{
			Uri: "https://127.0.0.1",
		},
		Flow: &mgmProto.FlowConfig{
			Url: "https://192.168.1.1:80",
		},
	}

	removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig)
	assert.NoError(t, err)

	// The cache was still empty before this update, so nothing can have been
	// removed. (A ">= 0" check on a length can never fail.)
	assert.Equal(t, 0, len(removedDomains), "No domains should be removed when nothing was cached")

	// Verify no domains were actually added since IPs are rejected
	domains := resolver.GetCachedDomains()
	assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_ContinueToNext(t *testing.T) {

View File

@@ -5,17 +5,19 @@ import (
"github.com/miekg/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
)
// MockServer is the mock instance of a dns server
type MockServer struct {
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int)
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int)
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
}
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -69,3 +71,10 @@ func (m *MockServer) SearchDomains() []string {
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
// UpdateServerConfig mocks implementation of UpdateServerConfig from the Server interface.
// It delegates to UpdateServerConfigFunc when one is set and returns nil otherwise.
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
	fn := m.UpdateServerConfigFunc
	if fn == nil {
		return nil
	}
	return fn(domains)
}

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
"github.com/netbirdio/netbird/client/internal/dns/types"
@@ -25,7 +26,6 @@ import (
cProto "github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
@@ -49,6 +49,7 @@ type Server interface {
OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string
ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
}
type nsGroupsByDomain struct {
@@ -103,20 +104,23 @@ type handlerWrapper struct {
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer.
type DefaultServerConfig struct {
// Ctx is handed to the created server and used when pre-populating the management cache.
Ctx context.Context
// WgInterface is the WireGuard interface; its IsUserspaceBind result selects the DNS service implementation.
WgInterface WGIface
// CustomAddress, when non-empty, is parsed with netip.ParseAddrPort; a parse failure makes NewDefaultServer return an error.
CustomAddress string
// StatusRecorder is forwarded to the created server.
StatusRecorder *peer.Status
// StateManager is forwarded to the created server.
StateManager *statemanager.Manager
// DisableSys is forwarded to the created server.
DisableSys bool
// MgmtURL, when non-nil, is used to pre-populate the management cache.
MgmtURL *url.URL
// ServerDomains is passed to UpdateServerConfig during construction to populate the management cache.
ServerDomains dnsconfig.ServerDomains
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
mgmtURL *url.URL,
netbirdConfig *mgmProto.NetbirdConfig,
) (*DefaultServer, error) {
func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
if err != nil {
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
}
@@ -124,31 +128,23 @@ func NewDefaultServer(
}
var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface)
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
dnsService = newServiceViaListener(config.WgInterface, addrPort)
}
server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys)
server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
// Pre-populate management cache with management URL
if mgmtURL != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil {
if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
log.Warnf("Failed to populate management cache from management URL: %v", err)
}
}
// Pre-populate management cache with NetbirdConfig domains
if netbirdConfig != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil {
log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err)
}
// Register newly populated domains
domains := server.mgmtCacheResolver.GetCachedDomains()
if len(domains) > 0 {
server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache)
if server.mgmtCacheResolver != nil {
if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
}
}
@@ -220,19 +216,11 @@ func newDefaultServer(
mgmtCacheResolver: mgmtCacheResolver,
}
// Register cached domains with the handler chain
registerMgmtCacheDomains := func() {
domains := mgmtCacheResolver.GetCachedDomains()
if len(domains) > 0 {
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
}
domains := mgmtCacheResolver.GetCachedDomains()
if len(domains) > 0 {
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
}
// Register any pre-populated domains from management cache
registerMgmtCacheDomains()
// Management cache resolver will be registered for specific domains when they are added
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
@@ -352,7 +340,6 @@ func (s *DefaultServer) Stop() {
}
}
s.service.Stop()
maps.Clear(s.extraDomains)
@@ -368,15 +355,6 @@ func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration
func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error {
if s.mgmtCacheResolver == nil {
return fmt.Errorf("management cache resolver not initialized")
}
log.Debug("populating management cache from netbird configuration")
return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config)
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone
@@ -476,6 +454,29 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Wait()
}
// UpdateServerConfig synchronizes the management cache resolver with domains
// and adjusts the handler registrations to match: domains dropped from the
// cache are deregistered and all currently cached domains are (re)registered
// at PriorityMgmtCache. It is a no-op when no management cache resolver is set.
//
// NOTE(review): RegisterHandler and DeregisterHandler are invoked while s.mux
// is held — confirm they do not acquire s.mux themselves.
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
	s.mux.Lock()
	defer s.mux.Unlock()

	resolver := s.mgmtCacheResolver
	if resolver == nil {
		return nil
	}

	removed, err := resolver.UpdateFromServerDomains(s.ctx, domains)
	if err != nil {
		return fmt.Errorf("update management cache resolver: %w", err)
	}
	if len(removed) != 0 {
		s.DeregisterHandler(removed, PriorityMgmtCache)
	}

	if cached := resolver.GetCachedDomains(); len(cached) != 0 {
		s.RegisterHandler(cached, resolver, PriorityMgmtCache)
	}

	return nil
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types"
@@ -363,7 +364,16 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Fatal(err)
}
@@ -473,7 +483,16 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -575,7 +594,16 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil)
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Fatalf("%v", err)
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
@@ -26,10 +27,10 @@ import (
)
const (
UpstreamTimeout = 15 * time.Second
UpstreamTimeout = 4 * time.Second
// ClientTimeout is the timeout for the dns.Client.
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 30 * time.Second
ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
@@ -105,52 +106,111 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a single DNS query by forwarding it to the configured
// upstream servers.
//
// The request is tagged with a generated request ID for log correlation and
// adjusted by prepareRequest. If the resolver context is already done the
// query is dropped without a reply. Otherwise the upstreams are tried via
// tryUpstreamServers; when none of them produces an answer the request is
// handed to writeErrorResponse.
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	requestID := GenerateRequestID()
	logger := log.WithField("request_id", requestID)

	logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)

	u.prepareRequest(r)

	if u.isContextDone(logger) {
		return
	}

	if u.tryUpstreamServers(w, r, logger) {
		return
	}

	u.writeErrorResponse(w, r, logger)
}
// prepareRequest adjusts the outgoing query before it is forwarded: when the
// message carries no Extra section, the AuthenticatedData header bit is set.
// Messages that already have an Extra section are left untouched.
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
	if r.Extra != nil {
		return
	}

	r.MsgHdr.AuthenticatedData = true
}
// isContextDone reports whether the resolver's context has already been
// cancelled. The check is non-blocking; when the context is done a trace
// message is emitted through logger and true is returned.
func (u *upstreamResolverBase) isContextDone(logger *log.Entry) bool {
	select {
	case <-u.ctx.Done():
		logger.Tracef("%s has been stopped", u)
		return true
	default:
		return false
	}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
}
for _, upstream := range u.upstreamServers {
var rm *dns.Msg
var t time.Duration
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
}()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
}
}
return false
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue
}
// queryUpstream sends r to a single upstream server, bounded by timeout, and
// writes the answer to w when one is received.
//
// It returns false when the exchange returns an error (reported through
// handleUpstreamError) or when the reply is nil or not flagged as a response.
// Otherwise it returns the result of writeSuccessResponse. Nothing is written
// to w and the success counter is left alone on the failure paths.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream string, timeout time.Duration, logger *log.Entry) bool {
	var rm *dns.Msg
	var t time.Duration
	var err error

	var startTime time.Time
	// The closure scopes the deferred cancel so the timeout context is
	// released as soon as the exchange returns, before the reply is written.
	func() {
		ctx, cancel := context.WithTimeout(u.ctx, timeout)
		defer cancel()
		startTime = time.Now()
		rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
	}()

	if err != nil {
		u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
		return false
	}

	if rm == nil || !rm.Response {
		logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
		return false
	}

	return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
}
// handleUpstreamError logs a failed upstream query at warning level.
//
// Errors that are neither context.DeadlineExceeded nor recognised by isTimeout
// are logged as plain query failures. Timeouts get an extended message that
// includes the time elapsed since startTime, the configured timeout and, when
// debugUpstreamTimeout returns something, the status of the peer the upstream
// is routed through.
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
	if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
		logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
		return
	}

	elapsed := time.Since(startTime)
	timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
	if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
		timeoutMsg += " " + peerInfo
	}
	timeoutMsg += fmt.Sprintf(" - error: %v", err)

	// The message is already fully formatted: log it verbatim instead of
	// passing it as a format string, so a '%' in the domain or error text is
	// not interpreted as a formatting verb.
	logger.Warn(timeoutMsg)
}
// writeSuccessResponse records a successful upstream query and relays the
// upstream answer rm to the client through w.
//
// The success counter is incremented and the query duration t is traced
// before the write is attempted. A failed write is logged but does not change
// the result: the function always returns true.
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream, domain string, t time.Duration, logger *log.Entry) bool {
	u.successCount.Add(1)
	logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)

	writeErr := w.WriteMsg(rm)
	if writeErr != nil {
		logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, writeErr)
	}

	return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
@@ -355,3 +415,97 @@ func GenerateRequestID() string {
}
return hex.EncodeToString(bytes)
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts.
//
// The result starts with "<FQDN>:<IP>" followed by space-separated markers:
//   - connection state: "DISCONNECTED" when the peer is not in
//     peer.StatusConnected, "NO_RECENT_HANDSHAKE" when it is connected but the
//     last WireGuard handshake is unset or at least 3 minutes old, "connected"
//     otherwise;
//   - "last_handshake=<age>_ago" (truncated to seconds) or "no_handshake";
//   - "via_relay" when the connection is relayed;
//   - "latency=<value>" when a positive latency is recorded.
//
// A nil peerState yields an empty string.
func FormatPeerStatus(peerState *peer.State) string {
	if peerState == nil {
		return ""
	}

	// Measure the handshake age once so the status marker and the printed age
	// are always derived from the same value.
	hasHandshake := !peerState.LastWireguardHandshake.IsZero()
	var timeSinceHandshake time.Duration
	if hasHandshake {
		timeSinceHandshake = time.Since(peerState.LastWireguardHandshake)
	}

	isConnected := peerState.ConnStatus == peer.StatusConnected
	hasRecentHandshake := hasHandshake && timeSinceHandshake < 3*time.Minute

	statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)

	switch {
	case !isConnected:
		statusInfo += " DISCONNECTED"
	case !hasRecentHandshake:
		statusInfo += " NO_RECENT_HANDSHAKE"
	default:
		statusInfo += " connected"
	}

	if hasHandshake {
		statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
	} else {
		statusInfo += " no_handshake"
	}

	if peerState.Relayed {
		statusInfo += " via_relay"
	}

	if peerState.Latency > 0 {
		statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
	}

	return statusInfo
}
// findPeerForIP finds which peer handles the given IP address
// findPeerForIP finds which peer handles the given IP address.
//
// Every route of every peer in the recorder's full status is parsed as a
// prefix; routes that fail to parse are ignored. Among the prefixes containing
// ip, the most specific one (longest prefix) wins, and a copy of the owning
// peer's state is returned. A default route (/0) is a valid match when nothing
// more specific covers ip.
//
// It returns nil when statusRecorder is nil or no route contains ip.
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
	if statusRecorder == nil {
		return nil
	}

	fullStatus := statusRecorder.GetFullStatus()

	var bestMatch *peer.State
	// Start below the shortest valid prefix length (0) so that a /0 default
	// route can still be selected.
	bestPrefixLen := -1

	for _, peerState := range fullStatus.Peers {
		for route := range peerState.GetRoutes() {
			prefix, err := netip.ParsePrefix(route)
			if err != nil {
				continue
			}

			if !prefix.Contains(ip) || prefix.Bits() <= bestPrefixLen {
				continue
			}

			// Copy before taking the address so the result does not alias
			// the loop variable.
			peerStateCopy := peerState
			bestMatch = &peerStateCopy
			bestPrefixLen = prefix.Bits()
		}
	}

	return bestMatch
}
// parseUpstreamIP parses an upstream server address to extract the IP
// parseUpstreamIP extracts the IP address from an upstream server address.
//
// The input may be a bare IP ("10.0.0.1", "::1") or an IP with a port
// ("10.0.0.1:53", "[::1]:53"). When the input is neither a bare IP nor
// splittable into host and port, the error from parsing it as a bare IP is
// returned; when it splits but the host part is not an IP, the error from
// parsing the host is returned.
func parseUpstreamIP(upstream string) (netip.Addr, error) {
	addr, parseErr := netip.ParseAddr(upstream)
	if parseErr == nil {
		return addr, nil
	}

	host, _, splitErr := net.SplitHostPort(upstream)
	if splitErr != nil {
		return netip.Addr{}, parseErr
	}

	return netip.ParseAddr(host)
}
// debugUpstreamTimeout builds a diagnostic suffix for timeout log messages,
// describing the NetBird peer whose routes cover the given upstream address.
//
// It returns an empty string when no status recorder is configured, when the
// upstream address cannot be parsed as an IP, or when no peer route contains
// that IP.
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream string) string {
	if u.statusRecorder == nil {
		return ""
	}

	addr, err := parseUpstreamIP(upstream)
	if err != nil {
		return ""
	}

	if routingPeer := findPeerForIP(addr, u.statusRecorder); routingPeer != nil {
		return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(routingPeer))
	}

	return ""
}

View File

@@ -33,6 +33,7 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -696,6 +697,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if e.dnsServer != nil {
serverDomains := config.ExtractFromNetbirdConfig(wCfg)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
}
// todo update signal
}
@@ -1604,7 +1612,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
return dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig)
// Extract domains from NetBird configuration
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig)
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
Ctx: e.ctx,
WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,
DisableSys: e.config.DisableDNS,
MgmtURL: mgmtURL,
ServerDomains: serverDomains,
})
if err != nil {
return nil, err
}

View File

@@ -2,11 +2,13 @@ package dnsinterceptor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
@@ -26,6 +28,8 @@ import (
"github.com/netbirdio/netbird/route"
)
const dnsTimeout = 8 * time.Second
type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
}
return
}
// debugPeerTimeout builds a diagnostic suffix for peer DNS timeout log
// messages, describing the connection state of the peer identified by peerKey.
//
// It returns an empty string when no status recorder is configured. When the
// peer cannot be looked up, the suffix carries the lookup error together with
// at most the first 8 characters of the key. peerIP is currently unused.
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
	if d.statusRecorder == nil {
		return ""
	}

	peerState, err := d.statusRecorder.GetPeer(peerKey)
	if err != nil {
		// Abbreviate the key for the log, without slicing past the end of a
		// key shorter than 8 bytes.
		shortKey := peerKey
		if len(shortKey) > 8 {
			shortKey = shortKey[:8]
		}
		return fmt.Sprintf(" (peer %s state error: %v)", shortKey, err)
	}

	return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
}