Add debug output for timeouts

This commit is contained in:
Viktor Liu
2025-07-09 19:00:17 +02:00
parent a1cb7b4af6
commit 629757c911
10 changed files with 736 additions and 204 deletions

View File

@@ -0,0 +1,155 @@
package config
import (
"fmt"
"net"
"net/netip"
"net/url"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// ServerDomains represents the management server domains extracted from NetBird configuration
type ServerDomains struct {
Signal domain.Domain
Relay []domain.Domain
Flow domain.Domain
Stuns []domain.Domain
Turns []domain.Domain
}
// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
if config == nil {
return ServerDomains{}
}
domains := ServerDomains{}
domains.Signal = extractSignalDomain(config)
domains.Relay = extractRelayDomains(config)
domains.Flow = extractFlowDomain(config)
domains.Stuns = extractStunDomains(config)
domains.Turns = extractTurnDomains(config)
return domains
}
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func extractValidDomain(rawURL string) (domain.Domain, error) {
parsedURL, err := url.Parse(rawURL)
if err != nil {
// If URL parsing fails, it might be a raw host:port, try parsing as such
if host, _, err := net.SplitHostPort(rawURL); err == nil {
return extractDomainFromHost(host)
}
// If not host:port, try as raw hostname
return extractDomainFromHost(rawURL)
}
host := parsedURL.Hostname()
if host == "" {
return "", fmt.Errorf("no hostname in URL")
}
return extractDomainFromHost(host)
}
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
func extractDomainFromHost(host string) (domain.Domain, error) {
if host == "" {
return "", fmt.Errorf("empty host")
}
if _, err := netip.ParseAddr(host); err == nil {
return "", fmt.Errorf("IP address not allowed: %s", host)
}
d, err := domain.FromString(host)
if err != nil {
return "", fmt.Errorf("invalid domain: %v", err)
}
return d, nil
}
// extractSingleDomain extracts a single domain from a URL with error logging
func extractSingleDomain(url, serviceType string) domain.Domain {
if url == "" {
return ""
}
d, err := extractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
return ""
}
return d
}
// extractMultipleDomains extracts multiple domains from URLs with error logging
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
var domains []domain.Domain
for _, url := range urls {
if url == "" {
continue
}
d, err := extractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
continue
}
domains = append(domains, d)
}
return domains
}
// extractSignalDomain extracts the signal domain from NetBird configuration.
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Signal != nil {
return extractSingleDomain(config.Signal.Uri, "signal")
}
return ""
}
// extractRelayDomains extracts relay server domains from NetBird configuration.
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay != nil {
return extractMultipleDomains(config.Relay.Urls, "relay")
}
return nil
}
// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Flow != nil {
return extractSingleDomain(config.Flow.Url, "flow")
}
return ""
}
// extractStunDomains extracts STUN server domains from NetBird configuration.
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
urls = append(urls, stun.Uri)
}
}
return extractMultipleDomains(urls, "STUN")
}
// extractTurnDomains extracts TURN server domains from NetBird configuration.
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
urls = append(urls, turn.HostConfig.Uri)
}
}
return extractMultipleDomains(urls, "TURN")
}

View File

@@ -182,7 +182,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return

View File

@@ -3,6 +3,7 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
@@ -13,6 +14,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -27,14 +29,12 @@ type CacheEntry struct {
type Resolver struct {
cache map[domain.Domain]CacheEntry
mutex sync.RWMutex
systemResolver *net.Resolver
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
cache: make(map[domain.Domain]CacheEntry),
systemResolver: net.DefaultResolver,
}
}
@@ -58,22 +58,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype])
m.mutex.RLock()
parsedDomain, err := domain.FromString(qname)
if err != nil {
log.Tracef("MgmtCache: invalid domain format: %s", qname)
m.mutex.RUnlock()
m.continueToNext(w, r)
return
}
entry, found := m.cache[parsedDomain]
domainKey := domain.Domain(qname)
entry, found := m.cache[domainKey]
m.mutex.RUnlock()
if !found {
log.Tracef("MgmtCache: no cache entry found for domain=%s", qname)
m.continueToNext(w, r)
return
}
@@ -91,7 +81,6 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
if len(records) == 0 {
log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString())
m.continueToNext(w, r)
return
}
@@ -102,10 +91,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp.Answer = append(resp.Answer, rrCopy)
}
log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString())
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
if err := w.WriteMsg(resp); err != nil {
log.Errorf("MgmtCache: failed to write response: %v", err)
log.Errorf("failed to write response: %v", err)
}
}
@@ -120,20 +109,23 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("MgmtCache: failed to write continue signal: %v", err)
log.Errorf("failed to write continue signal: %v", err)
}
}
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString())
log.Debugf("adding domain=%s to cache", d.SafeString())
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
var aRecords, aaaaRecords []dns.RR
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
}
if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil {
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
@@ -167,12 +159,8 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
}
m.mutex.Unlock()
log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records",
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
} else {
log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err)
return err
}
return nil
}
@@ -182,7 +170,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
if mgmtURL != nil {
if d, err := extractDomainFromURL(mgmtURL); err == nil {
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add management domain: %v", err)
log.Warnf("failed to add management domain: %v", err)
}
}
}
@@ -190,6 +178,16 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
return nil
}
// RemoveDomain removes a domain from the cache.
func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.cache, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
}
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
if config == nil {
@@ -216,19 +214,19 @@ func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostCon
// If parsing fails, it might be a raw host:port, try adding a scheme
signalURL, err = url.Parse("https://" + signal.Uri)
if err != nil {
log.Warnf("MgmtCache: failed to parse signal URL: %v", err)
log.Warnf("failed to parse signal URL: %v", err)
return
}
}
d, err := extractDomainFromURL(signalURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract signal domain: %v", err)
log.Warnf("failed to extract signal domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add signal domain: %v", err)
log.Warnf("failed to add signal domain: %v", err)
}
}
@@ -241,18 +239,18 @@ func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayCon
for _, relayAddr := range relay.Urls {
relayURL, err := url.Parse(relayAddr)
if err != nil {
log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err)
log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
continue
}
d, err := extractDomainFromURL(relayURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err)
log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add relay domain: %v", err)
log.Warnf("failed to add relay domain: %v", err)
}
}
}
@@ -265,18 +263,18 @@ func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig)
flowURL, err := url.Parse(flow.Url)
if err != nil {
log.Warnf("MgmtCache: failed to parse flow URL: %v", err)
log.Warnf("failed to parse flow URL: %v", err)
return
}
d, err := extractDomainFromURL(flowURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract flow domain: %v", err)
log.Warnf("failed to extract flow domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add flow domain: %v", err)
log.Warnf("failed to add flow domain: %v", err)
}
}
@@ -303,7 +301,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
}
m.cache = make(map[domain.Domain]CacheEntry)
log.Debugf("MgmtCache: cleared %d cached domains", len(domains))
log.Debugf("cleared %d cached domains", len(domains))
return domains
}
@@ -311,7 +309,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
// Returns domains that were removed for external deregistration.
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
log.Debugf("MgmtCache: updating cache from NetbirdConfig")
log.Debugf("updating cache from NetbirdConfig")
currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromConfig(config)
@@ -333,19 +331,86 @@ func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto
m.mutex.Lock()
for _, domainToRemove := range removedDomains {
delete(m.cache, domainToRemove)
log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString())
log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
}
m.mutex.Unlock()
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err)
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
}
}
return removedDomains, nil
}
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) {
log.Debugf("updating cache from ServerDomains")
currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains []domain.Domain
for _, currentDomain := range currentDomains {
found := false
for _, newDomain := range newDomains {
if currentDomain.SafeString() == newDomain.SafeString() {
found = true
break
}
}
if !found {
removedDomains = append(removedDomains, currentDomain)
if err := m.RemoveDomain(currentDomain); err != nil {
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
}
}
}
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
return removedDomains, nil
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain {
var domains []domain.Domain
if serverDomains.Signal != "" {
domains = append(domains, serverDomains.Signal)
}
for _, relay := range serverDomains.Relay {
if relay != "" {
domains = append(domains, relay)
}
}
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
for _, stun := range serverDomains.Stuns {
if stun != "" {
domains = append(domains, stun)
}
}
for _, turn := range serverDomains.Turns {
if turn != "" {
domains = append(domains, turn)
}
}
return domains
}
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
if config == nil {
@@ -354,26 +419,62 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
var domains []domain.Domain
if config.Signal != nil && config.Signal.Uri != "" {
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
domains = append(domains, d)
}
// Extract signal domain
domains = append(domains, m.extractSignalDomain(config)...)
// Extract relay domains
domains = append(domains, m.extractRelayDomains(config)...)
// Extract flow domain
domains = append(domains, m.extractFlowDomain(config)...)
// Extract STUN domains
domains = append(domains, m.extractSTUNDomains(config)...)
// Extract TURN domains
domains = append(domains, m.extractTURNDomains(config)...)
return domains
}
if config.Relay != nil {
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Signal == nil || config.Signal.Uri == "" {
return nil
}
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay == nil {
return nil
}
var domains []domain.Domain
for _, relayURL := range config.Relay.Urls {
if d, err := m.extractDomainFromURL(relayURL); err == nil {
domains = append(domains, d)
}
}
return domains
}
func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Flow == nil || config.Flow.Url == "" {
return nil
}
if config.Flow != nil && config.Flow.Url != "" {
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
domains = append(domains, d)
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
@@ -381,7 +482,11 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
}
}
}
return domains
}
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
@@ -389,7 +494,6 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
}
}
}
return domains
}
@@ -424,18 +528,18 @@ func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostCon
stunURL, err := url.Parse(stun.Uri)
if err != nil {
log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err)
log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
continue
}
d, err := extractDomainFromURL(stunURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err)
log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add STUN domain: %v", err)
log.Warnf("failed to add STUN domain: %v", err)
}
}
}
@@ -449,18 +553,18 @@ func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.Protect
turnURL, err := url.Parse(turn.HostConfig.Uri)
if err != nil {
log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
continue
}
d, err := extractDomainFromURL(turnURL)
if err != nil {
log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("MgmtCache: failed to add TURN domain: %v", err)
log.Warnf("failed to add TURN domain: %v", err)
}
}
}

View File

@@ -5,7 +5,6 @@ import (
"net"
"net/url"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -114,16 +113,15 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
resolver := NewResolver()
mgmtURL, _ := url.Parse("https://api.netbird.io")
// Use IP address to avoid DNS resolution timeout
mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.NoError(t, err)
// Give some time for async population
time.Sleep(100 * time.Millisecond)
// IP addresses are rejected, so no domains should be cached
domains := resolver.GetCachedDomains()
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
@@ -132,32 +130,33 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
resolver := NewResolver()
// Use IP addresses to avoid DNS resolution timeouts
netbirdConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{
Uri: "https://signal.netbird.io",
Uri: "https://10.0.0.1",
},
Relay: &mgmProto.RelayConfig{
Urls: []string{
"https://relay1.netbird.io:443",
"https://relay2.netbird.io:443",
"https://10.0.0.2:443",
"https://10.0.0.3:443",
},
},
Flow: &mgmProto.FlowConfig{
Url: "https://flow.netbird.io:80",
Url: "https://10.0.0.4:80",
},
Stuns: []*mgmProto.HostConfig{
{Uri: "stun:stun1.netbird.io:3478"},
{Uri: "stun:stun2.netbird.io:3478"},
{Uri: "stun:10.0.0.5:3478"},
{Uri: "stun:10.0.0.6:3478"},
},
Turns: []*mgmProto.ProtectedHostConfig{
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:turn1.netbird.io:3478",
Uri: "turn:10.0.0.7:3478",
},
},
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:turn2.netbird.io:3478",
Uri: "turn:10.0.0.8:3478",
},
},
},
@@ -166,11 +165,42 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
assert.NoError(t, err)
// Give some time for async population
time.Sleep(100 * time.Millisecond)
// IP addresses are rejected, so no domains should be cached
domains := resolver.GetCachedDomains()
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) {
resolver := NewResolver()
// Test with empty initial config and then add domains
initialConfig := &mgmProto.NetbirdConfig{}
// Start with empty config
removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig)
assert.NoError(t, err)
assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache")
// Update to config with IP addresses instead of domains to avoid DNS resolution
// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
updatedConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{
Uri: "https://127.0.0.1",
},
Flow: &mgmProto.FlowConfig{
Url: "https://192.168.1.1:80",
},
}
removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig)
assert.NoError(t, err)
// Verify the method completes successfully without DNS timeouts
assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update")
// Verify no domains were actually added since IPs are rejected
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_ContinueToNext(t *testing.T) {

View File

@@ -5,6 +5,7 @@ import (
"github.com/miekg/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
)
@@ -16,6 +17,7 @@ type MockServer struct {
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int)
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
}
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -69,3 +71,10 @@ func (m *MockServer) SearchDomains() []string {
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
}
return nil
}

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
"github.com/netbirdio/netbird/client/internal/dns/types"
@@ -25,7 +26,6 @@ import (
cProto "github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
@@ -49,6 +49,7 @@ type Server interface {
OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string
ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
}
type nsGroupsByDomain struct {
@@ -103,20 +104,23 @@ type handlerWrapper struct {
type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
Ctx context.Context
WgInterface WGIface
CustomAddress string
StatusRecorder *peer.Status
StateManager *statemanager.Manager
DisableSys bool
MgmtURL *url.URL
ServerDomains dnsconfig.ServerDomains
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
mgmtURL *url.URL,
netbirdConfig *mgmProto.NetbirdConfig,
) (*DefaultServer, error) {
func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
if err != nil {
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
}
@@ -124,31 +128,23 @@ func NewDefaultServer(
}
var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface)
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
dnsService = newServiceViaListener(config.WgInterface, addrPort)
}
server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys)
server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
// Pre-populate management cache with management URL
if mgmtURL != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil {
if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
log.Warnf("Failed to populate management cache from management URL: %v", err)
}
}
// Pre-populate management cache with NetbirdConfig domains
if netbirdConfig != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil {
log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err)
}
// Register newly populated domains
domains := server.mgmtCacheResolver.GetCachedDomains()
if len(domains) > 0 {
server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache)
if server.mgmtCacheResolver != nil {
if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
}
}
@@ -220,18 +216,10 @@ func newDefaultServer(
mgmtCacheResolver: mgmtCacheResolver,
}
// Register cached domains with the handler chain
registerMgmtCacheDomains := func() {
domains := mgmtCacheResolver.GetCachedDomains()
if len(domains) > 0 {
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
}
}
// Register any pre-populated domains from management cache
registerMgmtCacheDomains()
// Management cache resolver will be registered for specific domains when they are added
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
@@ -352,7 +340,6 @@ func (s *DefaultServer) Stop() {
}
}
s.service.Stop()
maps.Clear(s.extraDomains)
@@ -368,15 +355,6 @@ func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration
func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error {
if s.mgmtCacheResolver == nil {
return fmt.Errorf("management cache resolver not initialized")
}
log.Debug("populating management cache from netbird configuration")
return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config)
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone
@@ -476,6 +454,29 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Wait()
}
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
s.mux.Lock()
defer s.mux.Unlock()
if s.mgmtCacheResolver != nil {
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
if err != nil {
return fmt.Errorf("update management cache resolver: %w", err)
}
if len(removedDomains) > 0 {
s.DeregisterHandler(removedDomains, PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
if len(newDomains) > 0 {
s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache)
}
}
return nil
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types"
@@ -363,7 +364,16 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Fatal(err)
}
@@ -473,7 +483,16 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -575,7 +594,16 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil)
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Fatalf("%v", err)
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
@@ -26,10 +27,10 @@ import (
)
const (
UpstreamTimeout = 15 * time.Second
UpstreamTimeout = 4 * time.Second
// ClientTimeout is the timeout for the dns.Client.
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 30 * time.Second
ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
@@ -105,52 +106,111 @@ func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
if u.isContextDone(logger) {
return
}
if u.tryUpstreamServers(w, r, logger) {
return
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
}
func (u *upstreamResolverBase) isContextDone(logger *log.Entry) bool {
select {
case <-u.ctx.Done():
logger.Tracef("%s has been stopped", u)
return
return true
default:
return false
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
}
for _, upstream := range u.upstreamServers {
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
}
}
return false
}
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream string, timeout time.Duration, logger *log.Entry) bool {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
}()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
return false
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue
return false
}
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return
}
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
@@ -355,3 +415,97 @@ func GenerateRequestID() string {
}
return hex.EncodeToString(bytes)
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
switch {
case !isConnected:
statusInfo += " DISCONNECTED"
case !hasRecentHandshake:
statusInfo += " NO_RECENT_HANDSHAKE"
default:
statusInfo += " connected"
}
if !peerState.LastWireguardHandshake.IsZero() {
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
} else {
statusInfo += " no_handshake"
}
if peerState.Relayed {
statusInfo += " via_relay"
}
if peerState.Latency > 0 {
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
}
return statusInfo
}
// findPeerForIP finds which peer handles the given IP address
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
if statusRecorder == nil {
return nil
}
fullStatus := statusRecorder.GetFullStatus()
var bestMatch *peer.State
var bestPrefixLen int
for _, peerState := range fullStatus.Peers {
routes := peerState.GetRoutes()
for route := range routes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
continue
}
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
peerStateCopy := peerState
bestMatch = &peerStateCopy
bestPrefixLen = prefix.Bits()
}
}
}
return bestMatch
}
// parseUpstreamIP parses an upstream server address to extract the IP
func parseUpstreamIP(upstream string) (netip.Addr, error) {
upstreamIP, err := netip.ParseAddr(upstream)
if err != nil {
if host, _, err := net.SplitHostPort(upstream); err == nil {
return netip.ParseAddr(host)
}
return netip.Addr{}, err
}
return upstreamIP, nil
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream string) string {
if u.statusRecorder == nil {
return ""
}
upstreamIP, err := parseUpstreamIP(upstream)
if err != nil {
return ""
}
peerInfo := findPeerForIP(upstreamIP, u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}

View File

@@ -33,6 +33,7 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -696,6 +697,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if e.dnsServer != nil {
serverDomains := config.ExtractFromNetbirdConfig(wCfg)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
}
// todo update signal
}
@@ -1604,7 +1612,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
return dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig)
// Extract domains from NetBird configuration
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig)
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
Ctx: e.ctx,
WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,
DisableSys: e.config.DisableDNS,
MgmtURL: mgmtURL,
ServerDomains: serverDomains,
})
if err != nil {
return nil, err
}

View File

@@ -2,11 +2,13 @@ package dnsinterceptor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
@@ -26,6 +28,8 @@ import (
"github.com/netbirdio/netbird/route"
)
const dnsTimeout = 8 * time.Second
type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
}
return
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""
}
peerState, err := d.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
}
return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
}