mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 11:00:06 +02:00
Add debug output for timeouts
This commit is contained in:
155
client/internal/dns/config/domains.go
Normal file
155
client/internal/dns/config/domains.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// ServerDomains represents the management server domains extracted from NetBird configuration
|
||||
type ServerDomains struct {
|
||||
Signal domain.Domain
|
||||
Relay []domain.Domain
|
||||
Flow domain.Domain
|
||||
Stuns []domain.Domain
|
||||
Turns []domain.Domain
|
||||
}
|
||||
|
||||
// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
|
||||
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
|
||||
if config == nil {
|
||||
return ServerDomains{}
|
||||
}
|
||||
|
||||
domains := ServerDomains{}
|
||||
|
||||
domains.Signal = extractSignalDomain(config)
|
||||
domains.Relay = extractRelayDomains(config)
|
||||
domains.Flow = extractFlowDomain(config)
|
||||
domains.Stuns = extractStunDomains(config)
|
||||
domains.Turns = extractTurnDomains(config)
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
|
||||
func extractValidDomain(rawURL string) (domain.Domain, error) {
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
// If URL parsing fails, it might be a raw host:port, try parsing as such
|
||||
if host, _, err := net.SplitHostPort(rawURL); err == nil {
|
||||
return extractDomainFromHost(host)
|
||||
}
|
||||
// If not host:port, try as raw hostname
|
||||
return extractDomainFromHost(rawURL)
|
||||
}
|
||||
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("no hostname in URL")
|
||||
}
|
||||
|
||||
return extractDomainFromHost(host)
|
||||
}
|
||||
|
||||
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
|
||||
func extractDomainFromHost(host string) (domain.Domain, error) {
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
if _, err := netip.ParseAddr(host); err == nil {
|
||||
return "", fmt.Errorf("IP address not allowed: %s", host)
|
||||
}
|
||||
|
||||
d, err := domain.FromString(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid domain: %v", err)
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// extractSingleDomain extracts a single domain from a URL with error logging
|
||||
func extractSingleDomain(url, serviceType string) domain.Domain {
|
||||
if url == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
d, err := extractValidDomain(url)
|
||||
if err != nil {
|
||||
log.Debugf("Skipping %s: %v", serviceType, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// extractMultipleDomains extracts multiple domains from URLs with error logging
|
||||
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
|
||||
var domains []domain.Domain
|
||||
for _, url := range urls {
|
||||
if url == "" {
|
||||
continue
|
||||
}
|
||||
d, err := extractValidDomain(url)
|
||||
if err != nil {
|
||||
log.Debugf("Skipping %s: %v", serviceType, err)
|
||||
continue
|
||||
}
|
||||
domains = append(domains, d)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// extractSignalDomain extracts the signal domain from NetBird configuration.
|
||||
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
|
||||
if config.Signal != nil {
|
||||
return extractSingleDomain(config.Signal.Uri, "signal")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractRelayDomains extracts relay server domains from NetBird configuration.
|
||||
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
if config.Relay != nil {
|
||||
return extractMultipleDomains(config.Relay.Urls, "relay")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
|
||||
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
|
||||
if config.Flow != nil {
|
||||
return extractSingleDomain(config.Flow.Url, "flow")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractStunDomains extracts STUN server domains from NetBird configuration.
|
||||
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
var urls []string
|
||||
for _, stun := range config.Stuns {
|
||||
if stun != nil && stun.Uri != "" {
|
||||
urls = append(urls, stun.Uri)
|
||||
}
|
||||
}
|
||||
return extractMultipleDomains(urls, "STUN")
|
||||
}
|
||||
|
||||
// extractTurnDomains extracts TURN server domains from NetBird configuration.
|
||||
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
var urls []string
|
||||
for _, turn := range config.Turns {
|
||||
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
|
||||
urls = append(urls, turn.HostConfig.Uri)
|
||||
}
|
||||
}
|
||||
return extractMultipleDomains(urls, "TURN")
|
||||
}
|
@@ -182,7 +182,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
// Only log continue for non-management cache handlers to reduce noise
|
||||
if entry.Priority != PriorityMgmtCache {
|
||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
|
@@ -3,6 +3,7 @@ package mgmt
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
@@ -27,14 +29,12 @@ type CacheEntry struct {
|
||||
type Resolver struct {
|
||||
cache map[domain.Domain]CacheEntry
|
||||
mutex sync.RWMutex
|
||||
systemResolver *net.Resolver
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
cache: make(map[domain.Domain]CacheEntry),
|
||||
systemResolver: net.DefaultResolver,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,22 +58,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype])
|
||||
|
||||
m.mutex.RLock()
|
||||
parsedDomain, err := domain.FromString(qname)
|
||||
if err != nil {
|
||||
log.Tracef("MgmtCache: invalid domain format: %s", qname)
|
||||
m.mutex.RUnlock()
|
||||
m.continueToNext(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
entry, found := m.cache[parsedDomain]
|
||||
domainKey := domain.Domain(qname)
|
||||
entry, found := m.cache[domainKey]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
log.Tracef("MgmtCache: no cache entry found for domain=%s", qname)
|
||||
m.continueToNext(w, r)
|
||||
return
|
||||
}
|
||||
@@ -91,7 +81,6 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString())
|
||||
m.continueToNext(w, r)
|
||||
return
|
||||
}
|
||||
@@ -102,10 +91,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp.Answer = append(resp.Answer, rrCopy)
|
||||
}
|
||||
|
||||
log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString())
|
||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("MgmtCache: failed to write response: %v", err)
|
||||
log.Errorf("failed to write response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,20 +109,23 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("MgmtCache: failed to write continue signal: %v", err)
|
||||
log.Errorf("failed to write continue signal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AddDomain manually adds a domain to cache by resolving it.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString())
|
||||
log.Debugf("adding domain=%s to cache", d.SafeString())
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
|
||||
}
|
||||
|
||||
if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil {
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
rr := &dns.A{
|
||||
@@ -167,12 +159,8 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records",
|
||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
} else {
|
||||
log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -182,7 +170,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
|
||||
if mgmtURL != nil {
|
||||
if d, err := extractDomainFromURL(mgmtURL); err == nil {
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("MgmtCache: failed to add management domain: %v", err)
|
||||
log.Warnf("failed to add management domain: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -190,6 +178,16 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveDomain removes a domain from the cache.
|
||||
func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
delete(m.cache, d)
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
}
|
||||
|
||||
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
|
||||
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
|
||||
if config == nil {
|
||||
@@ -216,19 +214,19 @@ func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostCon
|
||||
// If parsing fails, it might be a raw host:port, try adding a scheme
|
||||
signalURL, err = url.Parse("https://" + signal.Uri)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to parse signal URL: %v", err)
|
||||
log.Warnf("failed to parse signal URL: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d, err := extractDomainFromURL(signalURL)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to extract signal domain: %v", err)
|
||||
log.Warnf("failed to extract signal domain: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("MgmtCache: failed to add signal domain: %v", err)
|
||||
log.Warnf("failed to add signal domain: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,18 +239,18 @@ func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayCon
|
||||
for _, relayAddr := range relay.Urls {
|
||||
relayURL, err := url.Parse(relayAddr)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err)
|
||||
log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
d, err := extractDomainFromURL(relayURL)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err)
|
||||
log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("MgmtCache: failed to add relay domain: %v", err)
|
||||
log.Warnf("failed to add relay domain: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -265,18 +263,18 @@ func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig)
|
||||
|
||||
flowURL, err := url.Parse(flow.Url)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to parse flow URL: %v", err)
|
||||
log.Warnf("failed to parse flow URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
d, err := extractDomainFromURL(flowURL)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to extract flow domain: %v", err)
|
||||
log.Warnf("failed to extract flow domain: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("MgmtCache: failed to add flow domain: %v", err)
|
||||
log.Warnf("failed to add flow domain: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,7 +301,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
|
||||
}
|
||||
|
||||
m.cache = make(map[domain.Domain]CacheEntry)
|
||||
log.Debugf("MgmtCache: cleared %d cached domains", len(domains))
|
||||
log.Debugf("cleared %d cached domains", len(domains))
|
||||
|
||||
return domains
|
||||
}
|
||||
@@ -311,7 +309,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
|
||||
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
|
||||
// Returns domains that were removed for external deregistration.
|
||||
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
|
||||
log.Debugf("MgmtCache: updating cache from NetbirdConfig")
|
||||
log.Debugf("updating cache from NetbirdConfig")
|
||||
|
||||
currentDomains := m.GetCachedDomains()
|
||||
newDomains := m.extractDomainsFromConfig(config)
|
||||
@@ -333,19 +331,86 @@ func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto
|
||||
m.mutex.Lock()
|
||||
for _, domainToRemove := range removedDomains {
|
||||
delete(m.cache, domainToRemove)
|
||||
log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString())
|
||||
log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
for _, newDomain := range newDomains {
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return removedDomains, nil
|
||||
}
|
||||
|
||||
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
|
||||
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) {
|
||||
log.Debugf("updating cache from ServerDomains")
|
||||
|
||||
currentDomains := m.GetCachedDomains()
|
||||
newDomains := m.extractDomainsFromServerDomains(serverDomains)
|
||||
|
||||
var removedDomains []domain.Domain
|
||||
for _, currentDomain := range currentDomains {
|
||||
found := false
|
||||
for _, newDomain := range newDomains {
|
||||
if currentDomain.SafeString() == newDomain.SafeString() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
removedDomains = append(removedDomains, currentDomain)
|
||||
if err := m.RemoveDomain(currentDomain); err != nil {
|
||||
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, newDomain := range newDomains {
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
}
|
||||
}
|
||||
|
||||
return removedDomains, nil
|
||||
}
|
||||
|
||||
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain {
|
||||
var domains []domain.Domain
|
||||
|
||||
if serverDomains.Signal != "" {
|
||||
domains = append(domains, serverDomains.Signal)
|
||||
}
|
||||
|
||||
for _, relay := range serverDomains.Relay {
|
||||
if relay != "" {
|
||||
domains = append(domains, relay)
|
||||
}
|
||||
}
|
||||
|
||||
if serverDomains.Flow != "" {
|
||||
domains = append(domains, serverDomains.Flow)
|
||||
}
|
||||
|
||||
for _, stun := range serverDomains.Stuns {
|
||||
if stun != "" {
|
||||
domains = append(domains, stun)
|
||||
}
|
||||
}
|
||||
|
||||
for _, turn := range serverDomains.Turns {
|
||||
if turn != "" {
|
||||
domains = append(domains, turn)
|
||||
}
|
||||
}
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
|
||||
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
if config == nil {
|
||||
@@ -354,26 +419,62 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
|
||||
|
||||
var domains []domain.Domain
|
||||
|
||||
if config.Signal != nil && config.Signal.Uri != "" {
|
||||
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
// Extract signal domain
|
||||
domains = append(domains, m.extractSignalDomain(config)...)
|
||||
|
||||
// Extract relay domains
|
||||
domains = append(domains, m.extractRelayDomains(config)...)
|
||||
|
||||
// Extract flow domain
|
||||
domains = append(domains, m.extractFlowDomain(config)...)
|
||||
|
||||
// Extract STUN domains
|
||||
domains = append(domains, m.extractSTUNDomains(config)...)
|
||||
|
||||
// Extract TURN domains
|
||||
domains = append(domains, m.extractTURNDomains(config)...)
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
if config.Relay != nil {
|
||||
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
if config.Signal == nil || config.Signal.Uri == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
|
||||
return []domain.Domain{d}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
if config.Relay == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var domains []domain.Domain
|
||||
for _, relayURL := range config.Relay.Urls {
|
||||
if d, err := m.extractDomainFromURL(relayURL); err == nil {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
if config.Flow == nil || config.Flow.Url == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config.Flow != nil && config.Flow.Url != "" {
|
||||
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
|
||||
domains = append(domains, d)
|
||||
return []domain.Domain{d}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
var domains []domain.Domain
|
||||
for _, stun := range config.Stuns {
|
||||
if stun != nil && stun.Uri != "" {
|
||||
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
|
||||
@@ -381,7 +482,11 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
|
||||
}
|
||||
}
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||
var domains []domain.Domain
|
||||
for _, turn := range config.Turns {
|
||||
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
|
||||
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
|
||||
@@ -389,7 +494,6 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
@@ -424,18 +528,18 @@ func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostCon
|
||||
|
||||
stunURL, err := url.Parse(stun.Uri)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err)
|
||||
log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
|
||||
continue
|
||||
}
|
||||
|
||||
d, err := extractDomainFromURL(stunURL)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err)
|
||||
log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("MgmtCache: failed to add STUN domain: %v", err)
|
||||
log.Warnf("failed to add STUN domain: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -449,18 +553,18 @@ func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.Protect
|
||||
|
||||
turnURL, err := url.Parse(turn.HostConfig.Uri)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
|
||||
log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
|
||||
continue
|
||||
}
|
||||
|
||||
d, err := extractDomainFromURL(turnURL)
|
||||
if err != nil {
|
||||
log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
|
||||
log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("MgmtCache: failed to add TURN domain: %v", err)
|
||||
log.Warnf("failed to add TURN domain: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -114,16 +113,15 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
||||
|
||||
resolver := NewResolver()
|
||||
|
||||
mgmtURL, _ := url.Parse("https://api.netbird.io")
|
||||
// Use IP address to avoid DNS resolution timeout
|
||||
mgmtURL, _ := url.Parse("https://127.0.0.1")
|
||||
|
||||
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Give some time for async population
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// IP addresses are rejected, so no domains should be cached
|
||||
domains := resolver.GetCachedDomains()
|
||||
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
|
||||
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
||||
}
|
||||
|
||||
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
||||
@@ -132,32 +130,33 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
||||
|
||||
resolver := NewResolver()
|
||||
|
||||
// Use IP addresses to avoid DNS resolution timeouts
|
||||
netbirdConfig := &mgmProto.NetbirdConfig{
|
||||
Signal: &mgmProto.HostConfig{
|
||||
Uri: "https://signal.netbird.io",
|
||||
Uri: "https://10.0.0.1",
|
||||
},
|
||||
Relay: &mgmProto.RelayConfig{
|
||||
Urls: []string{
|
||||
"https://relay1.netbird.io:443",
|
||||
"https://relay2.netbird.io:443",
|
||||
"https://10.0.0.2:443",
|
||||
"https://10.0.0.3:443",
|
||||
},
|
||||
},
|
||||
Flow: &mgmProto.FlowConfig{
|
||||
Url: "https://flow.netbird.io:80",
|
||||
Url: "https://10.0.0.4:80",
|
||||
},
|
||||
Stuns: []*mgmProto.HostConfig{
|
||||
{Uri: "stun:stun1.netbird.io:3478"},
|
||||
{Uri: "stun:stun2.netbird.io:3478"},
|
||||
{Uri: "stun:10.0.0.5:3478"},
|
||||
{Uri: "stun:10.0.0.6:3478"},
|
||||
},
|
||||
Turns: []*mgmProto.ProtectedHostConfig{
|
||||
{
|
||||
HostConfig: &mgmProto.HostConfig{
|
||||
Uri: "turn:turn1.netbird.io:3478",
|
||||
Uri: "turn:10.0.0.7:3478",
|
||||
},
|
||||
},
|
||||
{
|
||||
HostConfig: &mgmProto.HostConfig{
|
||||
Uri: "turn:turn2.netbird.io:3478",
|
||||
Uri: "turn:10.0.0.8:3478",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -166,11 +165,42 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
||||
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Give some time for async population
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// IP addresses are rejected, so no domains should be cached
|
||||
domains := resolver.GetCachedDomains()
|
||||
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
|
||||
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
||||
}
|
||||
|
||||
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// Test with empty initial config and then add domains
|
||||
initialConfig := &mgmProto.NetbirdConfig{}
|
||||
|
||||
// Start with empty config
|
||||
removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache")
|
||||
|
||||
// Update to config with IP addresses instead of domains to avoid DNS resolution
|
||||
// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
|
||||
updatedConfig := &mgmProto.NetbirdConfig{
|
||||
Signal: &mgmProto.HostConfig{
|
||||
Uri: "https://127.0.0.1",
|
||||
},
|
||||
Flow: &mgmProto.FlowConfig{
|
||||
Url: "https://192.168.1.1:80",
|
||||
},
|
||||
}
|
||||
|
||||
removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify the method completes successfully without DNS timeouts
|
||||
assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update")
|
||||
|
||||
// Verify no domains were actually added since IPs are rejected
|
||||
domains := resolver.GetCachedDomains()
|
||||
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
||||
}
|
||||
|
||||
func TestResolver_ContinueToNext(t *testing.T) {
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
@@ -16,6 +17,7 @@ type MockServer struct {
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||
DeregisterHandlerFunc func(domain.List, int)
|
||||
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
@@ -69,3 +71,10 @@ func (m *MockServer) SearchDomains() []string {
|
||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||
func (m *MockServer) ProbeAvailability() {
|
||||
}
|
||||
|
||||
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
if m.UpdateServerConfigFunc != nil {
|
||||
return m.UpdateServerConfigFunc(domains)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
@@ -25,7 +26,6 @@ import (
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
@@ -49,6 +49,7 @@ type Server interface {
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
SearchDomains() []string
|
||||
ProbeAvailability()
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -103,20 +104,23 @@ type handlerWrapper struct {
|
||||
|
||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||
|
||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||
type DefaultServerConfig struct {
|
||||
Ctx context.Context
|
||||
WgInterface WGIface
|
||||
CustomAddress string
|
||||
StatusRecorder *peer.Status
|
||||
StateManager *statemanager.Manager
|
||||
DisableSys bool
|
||||
MgmtURL *url.URL
|
||||
ServerDomains dnsconfig.ServerDomains
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
customAddress string,
|
||||
statusRecorder *peer.Status,
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
mgmtURL *url.URL,
|
||||
netbirdConfig *mgmProto.NetbirdConfig,
|
||||
) (*DefaultServer, error) {
|
||||
func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
if config.CustomAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||
}
|
||||
@@ -124,31 +128,23 @@ func NewDefaultServer(
|
||||
}
|
||||
|
||||
var dnsService service
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(wgInterface)
|
||||
if config.WgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||
}
|
||||
|
||||
server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys)
|
||||
server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||
|
||||
// Pre-populate management cache with management URL
|
||||
if mgmtURL != nil && server.mgmtCacheResolver != nil {
|
||||
if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil {
|
||||
if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
|
||||
if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
|
||||
log.Warnf("Failed to populate management cache from management URL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-populate management cache with NetbirdConfig domains
|
||||
if netbirdConfig != nil && server.mgmtCacheResolver != nil {
|
||||
if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil {
|
||||
log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err)
|
||||
}
|
||||
|
||||
// Register newly populated domains
|
||||
domains := server.mgmtCacheResolver.GetCachedDomains()
|
||||
if len(domains) > 0 {
|
||||
server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache)
|
||||
if server.mgmtCacheResolver != nil {
|
||||
if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
|
||||
log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,18 +216,10 @@ func newDefaultServer(
|
||||
mgmtCacheResolver: mgmtCacheResolver,
|
||||
}
|
||||
|
||||
// Register cached domains with the handler chain
|
||||
registerMgmtCacheDomains := func() {
|
||||
domains := mgmtCacheResolver.GetCachedDomains()
|
||||
if len(domains) > 0 {
|
||||
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
|
||||
}
|
||||
}
|
||||
|
||||
// Register any pre-populated domains from management cache
|
||||
registerMgmtCacheDomains()
|
||||
|
||||
// Management cache resolver will be registered for specific domains when they are added
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
dnsService.RegisterMux(".", handlerChain)
|
||||
@@ -352,7 +340,6 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
s.service.Stop()
|
||||
|
||||
maps.Clear(s.extraDomains)
|
||||
@@ -368,15 +355,6 @@ func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
|
||||
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
|
||||
}
|
||||
|
||||
// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration
|
||||
func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error {
|
||||
if s.mgmtCacheResolver == nil {
|
||||
return fmt.Errorf("management cache resolver not initialized")
|
||||
}
|
||||
|
||||
log.Debug("populating management cache from netbird configuration")
|
||||
return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config)
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
@@ -476,6 +454,29 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
if s.mgmtCacheResolver != nil {
|
||||
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update management cache resolver: %w", err)
|
||||
}
|
||||
|
||||
if len(removedDomains) > 0 {
|
||||
s.DeregisterHandler(removedDomains, PriorityMgmtCache)
|
||||
}
|
||||
|
||||
newDomains := s.mgmtCacheResolver.GetCachedDomains()
|
||||
if len(newDomains) > 0 {
|
||||
s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be Disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
|
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
@@ -363,7 +364,16 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
|
||||
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
||||
Ctx: context.Background(),
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
MgmtURL: nil,
|
||||
ServerDomains: dnsconfig.ServerDomains{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -473,7 +483,16 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
|
||||
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
||||
Ctx: context.Background(),
|
||||
WgInterface: wgIface,
|
||||
CustomAddress: "",
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
MgmtURL: nil,
|
||||
ServerDomains: dnsconfig.ServerDomains{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -575,7 +594,16 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil)
|
||||
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
||||
Ctx: context.Background(),
|
||||
WgInterface: &mocWGIface{},
|
||||
CustomAddress: testCase.addrPort,
|
||||
StatusRecorder: peer.NewRecorder("mgm"),
|
||||
StateManager: nil,
|
||||
DisableSys: false,
|
||||
MgmtURL: nil,
|
||||
ServerDomains: dnsconfig.ServerDomains{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -26,10 +27,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
UpstreamTimeout = 15 * time.Second
|
||||
UpstreamTimeout = 4 * time.Second
|
||||
// ClientTimeout is the timeout for the dns.Client.
|
||||
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
||||
ClientTimeout = 30 * time.Second
|
||||
ClientTimeout = 5 * time.Second
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
@@ -105,52 +106,111 @@ func (u *upstreamResolverBase) Stop() {
|
||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
requestID := GenerateRequestID()
|
||||
logger := log.WithField("request_id", requestID)
|
||||
var err error
|
||||
|
||||
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
u.prepareRequest(r)
|
||||
|
||||
if u.isContextDone(logger) {
|
||||
return
|
||||
}
|
||||
|
||||
if u.tryUpstreamServers(w, r, logger) {
|
||||
return
|
||||
}
|
||||
|
||||
u.writeErrorResponse(w, r, logger)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
if r.Extra == nil {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) isContextDone(logger *log.Entry) bool {
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
logger.Tracef("%s has been stopped", u)
|
||||
return
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
minPerUpstream := 2 * time.Second
|
||||
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
|
||||
if scaledTimeout > minPerUpstream {
|
||||
timeout = scaledTimeout
|
||||
} else {
|
||||
timeout = minPerUpstream
|
||||
}
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if u.queryUpstream(w, r, upstream, timeout, logger) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream string, timeout time.Duration, logger *log.Entry) bool {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
var startTime time.Time
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
defer cancel()
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||
continue
|
||||
}
|
||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||
continue
|
||||
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
|
||||
return false
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||
continue
|
||||
return false
|
||||
}
|
||||
|
||||
u.successCount.Add(1)
|
||||
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||
|
||||
if err = w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
|
||||
return
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
|
||||
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
|
||||
timeoutMsg += " " + peerInfo
|
||||
}
|
||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||
logger.Warnf(timeoutMsg)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
|
||||
|
||||
if err := w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||
|
||||
m := new(dns.Msg)
|
||||
@@ -355,3 +415,97 @@ func GenerateRequestID() string {
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||
func FormatPeerStatus(peerState *peer.State) string {
|
||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
|
||||
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
|
||||
|
||||
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
|
||||
|
||||
switch {
|
||||
case !isConnected:
|
||||
statusInfo += " DISCONNECTED"
|
||||
case !hasRecentHandshake:
|
||||
statusInfo += " NO_RECENT_HANDSHAKE"
|
||||
default:
|
||||
statusInfo += " connected"
|
||||
}
|
||||
|
||||
if !peerState.LastWireguardHandshake.IsZero() {
|
||||
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
|
||||
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
|
||||
} else {
|
||||
statusInfo += " no_handshake"
|
||||
}
|
||||
|
||||
if peerState.Relayed {
|
||||
statusInfo += " via_relay"
|
||||
}
|
||||
|
||||
if peerState.Latency > 0 {
|
||||
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
|
||||
}
|
||||
|
||||
return statusInfo
|
||||
}
|
||||
|
||||
// findPeerForIP finds which peer handles the given IP address
|
||||
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
|
||||
if statusRecorder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
var bestMatch *peer.State
|
||||
var bestPrefixLen int
|
||||
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
routes := peerState.GetRoutes()
|
||||
for route := range routes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
|
||||
peerStateCopy := peerState
|
||||
bestMatch = &peerStateCopy
|
||||
bestPrefixLen = prefix.Bits()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bestMatch
|
||||
}
|
||||
|
||||
// parseUpstreamIP parses an upstream server address to extract the IP
|
||||
func parseUpstreamIP(upstream string) (netip.Addr, error) {
|
||||
upstreamIP, err := netip.ParseAddr(upstream)
|
||||
if err != nil {
|
||||
if host, _, err := net.SplitHostPort(upstream); err == nil {
|
||||
return netip.ParseAddr(host)
|
||||
}
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
return upstreamIP, nil
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream string) string {
|
||||
if u.statusRecorder == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
upstreamIP, err := parseUpstreamIP(upstream)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
peerInfo := findPeerForIP(upstreamIP, u.statusRecorder)
|
||||
if peerInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||
}
|
||||
|
@@ -33,6 +33,7 @@ import (
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -696,6 +697,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if e.dnsServer != nil {
|
||||
serverDomains := config.ExtractFromNetbirdConfig(wCfg)
|
||||
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
@@ -1604,7 +1612,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig)
|
||||
// Extract domains from NetBird configuration
|
||||
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig)
|
||||
|
||||
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
|
||||
Ctx: e.ctx,
|
||||
WgInterface: e.wgInterface,
|
||||
CustomAddress: e.config.CustomDNSAddress,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
StateManager: e.stateManager,
|
||||
DisableSys: e.config.DisableDNS,
|
||||
MgmtURL: mgmtURL,
|
||||
ServerDomains: serverDomains,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -2,11 +2,13 @@ package dnsinterceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
@@ -26,6 +28,8 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const dnsTimeout = 8 * time.Second
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type internalDNATer interface {
|
||||
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
|
||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||
if err != nil {
|
||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
||||
return
|
||||
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
elapsed := time.Since(startTime)
|
||||
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||
} else {
|
||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
}
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
logger.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
|
||||
if d.statusRecorder == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
peerState, err := d.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
|
||||
}
|
||||
|
Reference in New Issue
Block a user