mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 19:09:09 +02:00
Add debug output for timeouts
This commit is contained in:
155
client/internal/dns/config/domains.go
Normal file
155
client/internal/dns/config/domains.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerDomains represents the management server domains extracted from NetBird configuration
|
||||||
|
type ServerDomains struct {
|
||||||
|
Signal domain.Domain
|
||||||
|
Relay []domain.Domain
|
||||||
|
Flow domain.Domain
|
||||||
|
Stuns []domain.Domain
|
||||||
|
Turns []domain.Domain
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
|
||||||
|
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
|
||||||
|
if config == nil {
|
||||||
|
return ServerDomains{}
|
||||||
|
}
|
||||||
|
|
||||||
|
domains := ServerDomains{}
|
||||||
|
|
||||||
|
domains.Signal = extractSignalDomain(config)
|
||||||
|
domains.Relay = extractRelayDomains(config)
|
||||||
|
domains.Flow = extractFlowDomain(config)
|
||||||
|
domains.Stuns = extractStunDomains(config)
|
||||||
|
domains.Turns = extractTurnDomains(config)
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
|
||||||
|
func extractValidDomain(rawURL string) (domain.Domain, error) {
|
||||||
|
parsedURL, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
// If URL parsing fails, it might be a raw host:port, try parsing as such
|
||||||
|
if host, _, err := net.SplitHostPort(rawURL); err == nil {
|
||||||
|
return extractDomainFromHost(host)
|
||||||
|
}
|
||||||
|
// If not host:port, try as raw hostname
|
||||||
|
return extractDomainFromHost(rawURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
host := parsedURL.Hostname()
|
||||||
|
if host == "" {
|
||||||
|
return "", fmt.Errorf("no hostname in URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
return extractDomainFromHost(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
|
||||||
|
func extractDomainFromHost(host string) (domain.Domain, error) {
|
||||||
|
if host == "" {
|
||||||
|
return "", fmt.Errorf("empty host")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := netip.ParseAddr(host); err == nil {
|
||||||
|
return "", fmt.Errorf("IP address not allowed: %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := domain.FromString(host)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid domain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSingleDomain extracts a single domain from a URL with error logging
|
||||||
|
func extractSingleDomain(url, serviceType string) domain.Domain {
|
||||||
|
if url == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := extractValidDomain(url)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Skipping %s: %v", serviceType, err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractMultipleDomains extracts multiple domains from URLs with error logging
|
||||||
|
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
|
||||||
|
var domains []domain.Domain
|
||||||
|
for _, url := range urls {
|
||||||
|
if url == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d, err := extractValidDomain(url)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Skipping %s: %v", serviceType, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSignalDomain extracts the signal domain from NetBird configuration.
|
||||||
|
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
|
||||||
|
if config.Signal != nil {
|
||||||
|
return extractSingleDomain(config.Signal.Uri, "signal")
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRelayDomains extracts relay server domains from NetBird configuration.
|
||||||
|
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
if config.Relay != nil {
|
||||||
|
return extractMultipleDomains(config.Relay.Urls, "relay")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
|
||||||
|
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
|
||||||
|
if config.Flow != nil {
|
||||||
|
return extractSingleDomain(config.Flow.Url, "flow")
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractStunDomains extracts STUN server domains from NetBird configuration.
|
||||||
|
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
var urls []string
|
||||||
|
for _, stun := range config.Stuns {
|
||||||
|
if stun != nil && stun.Uri != "" {
|
||||||
|
urls = append(urls, stun.Uri)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return extractMultipleDomains(urls, "STUN")
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTurnDomains extracts TURN server domains from NetBird configuration.
|
||||||
|
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
var urls []string
|
||||||
|
for _, turn := range config.Turns {
|
||||||
|
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
|
||||||
|
urls = append(urls, turn.HostConfig.Uri)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return extractMultipleDomains(urls, "TURN")
|
||||||
|
}
|
@@ -182,7 +182,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// If handler wants to continue, try next handler
|
// If handler wants to continue, try next handler
|
||||||
if chainWriter.shouldContinue {
|
if chainWriter.shouldContinue {
|
||||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
// Only log continue for non-management cache handlers to reduce noise
|
||||||
|
if entry.Priority != PriorityMgmtCache {
|
||||||
|
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@@ -3,6 +3,7 @@ package mgmt
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
@@ -27,14 +29,12 @@ type CacheEntry struct {
|
|||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
cache map[domain.Domain]CacheEntry
|
cache map[domain.Domain]CacheEntry
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
systemResolver *net.Resolver
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
cache: make(map[domain.Domain]CacheEntry),
|
cache: make(map[domain.Domain]CacheEntry),
|
||||||
systemResolver: net.DefaultResolver,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,22 +58,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype])
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
parsedDomain, err := domain.FromString(qname)
|
domainKey := domain.Domain(qname)
|
||||||
if err != nil {
|
entry, found := m.cache[domainKey]
|
||||||
log.Tracef("MgmtCache: invalid domain format: %s", qname)
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
m.continueToNext(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
entry, found := m.cache[parsedDomain]
|
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
log.Tracef("MgmtCache: no cache entry found for domain=%s", qname)
|
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -91,7 +81,6 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(records) == 0 {
|
if len(records) == 0 {
|
||||||
log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString())
|
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -102,10 +91,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
resp.Answer = append(resp.Answer, rrCopy)
|
resp.Answer = append(resp.Answer, rrCopy)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString())
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("MgmtCache: failed to write response: %v", err)
|
log.Errorf("failed to write response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,60 +109,59 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
resp.SetRcode(r, dns.RcodeNameError)
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
resp.MsgHdr.Zero = true
|
resp.MsgHdr.Zero = true
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("MgmtCache: failed to write continue signal: %v", err)
|
log.Errorf("failed to write continue signal: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDomain manually adds a domain to cache by resolving it.
|
// AddDomain manually adds a domain to cache by resolving it.
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString())
|
log.Debugf("adding domain=%s to cache", d.SafeString())
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
var aRecords, aaaaRecords []dns.RR
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||||
|
if err != nil {
|
||||||
if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil {
|
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
|
||||||
for _, ip := range ips {
|
|
||||||
if ip.Is4() {
|
|
||||||
rr := &dns.A{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: d.PunycodeString() + ".",
|
|
||||||
Rrtype: dns.TypeA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 300,
|
|
||||||
},
|
|
||||||
A: ip.AsSlice(),
|
|
||||||
}
|
|
||||||
aRecords = append(aRecords, rr)
|
|
||||||
} else if ip.Is6() {
|
|
||||||
rr := &dns.AAAA{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: d.PunycodeString() + ".",
|
|
||||||
Rrtype: dns.TypeAAAA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 300,
|
|
||||||
},
|
|
||||||
AAAA: ip.AsSlice(),
|
|
||||||
}
|
|
||||||
aaaaRecords = append(aaaaRecords, rr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.cache[d] = CacheEntry{
|
|
||||||
ARecords: aRecords,
|
|
||||||
AAAARecords: aaaaRecords,
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records",
|
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
|
||||||
} else {
|
|
||||||
log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var aRecords, aaaaRecords []dns.RR
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ip.Is4() {
|
||||||
|
rr := &dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: d.PunycodeString() + ".",
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
A: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aRecords = append(aRecords, rr)
|
||||||
|
} else if ip.Is6() {
|
||||||
|
rr := &dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: d.PunycodeString() + ".",
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
AAAA: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aaaaRecords = append(aaaaRecords, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.cache[d] = CacheEntry{
|
||||||
|
ARecords: aRecords,
|
||||||
|
AAAARecords: aaaaRecords,
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||||
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +170,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
|
|||||||
if mgmtURL != nil {
|
if mgmtURL != nil {
|
||||||
if d, err := extractDomainFromURL(mgmtURL); err == nil {
|
if d, err := extractDomainFromURL(mgmtURL); err == nil {
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
log.Warnf("MgmtCache: failed to add management domain: %v", err)
|
log.Warnf("failed to add management domain: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -190,6 +178,16 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveDomain removes a domain from the cache.
|
||||||
|
func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
delete(m.cache, d)
|
||||||
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
|
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
|
||||||
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
|
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
@@ -216,19 +214,19 @@ func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostCon
|
|||||||
// If parsing fails, it might be a raw host:port, try adding a scheme
|
// If parsing fails, it might be a raw host:port, try adding a scheme
|
||||||
signalURL, err = url.Parse("https://" + signal.Uri)
|
signalURL, err = url.Parse("https://" + signal.Uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to parse signal URL: %v", err)
|
log.Warnf("failed to parse signal URL: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := extractDomainFromURL(signalURL)
|
d, err := extractDomainFromURL(signalURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to extract signal domain: %v", err)
|
log.Warnf("failed to extract signal domain: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
log.Warnf("MgmtCache: failed to add signal domain: %v", err)
|
log.Warnf("failed to add signal domain: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,18 +239,18 @@ func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayCon
|
|||||||
for _, relayAddr := range relay.Urls {
|
for _, relayAddr := range relay.Urls {
|
||||||
relayURL, err := url.Parse(relayAddr)
|
relayURL, err := url.Parse(relayAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err)
|
log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := extractDomainFromURL(relayURL)
|
d, err := extractDomainFromURL(relayURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err)
|
log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
log.Warnf("MgmtCache: failed to add relay domain: %v", err)
|
log.Warnf("failed to add relay domain: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -265,18 +263,18 @@ func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig)
|
|||||||
|
|
||||||
flowURL, err := url.Parse(flow.Url)
|
flowURL, err := url.Parse(flow.Url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to parse flow URL: %v", err)
|
log.Warnf("failed to parse flow URL: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := extractDomainFromURL(flowURL)
|
d, err := extractDomainFromURL(flowURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to extract flow domain: %v", err)
|
log.Warnf("failed to extract flow domain: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
log.Warnf("MgmtCache: failed to add flow domain: %v", err)
|
log.Warnf("failed to add flow domain: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,7 +301,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.cache = make(map[domain.Domain]CacheEntry)
|
m.cache = make(map[domain.Domain]CacheEntry)
|
||||||
log.Debugf("MgmtCache: cleared %d cached domains", len(domains))
|
log.Debugf("cleared %d cached domains", len(domains))
|
||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
@@ -311,7 +309,7 @@ func (m *Resolver) ClearCache() []domain.Domain {
|
|||||||
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
|
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
|
||||||
// Returns domains that were removed for external deregistration.
|
// Returns domains that were removed for external deregistration.
|
||||||
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
|
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
|
||||||
log.Debugf("MgmtCache: updating cache from NetbirdConfig")
|
log.Debugf("updating cache from NetbirdConfig")
|
||||||
|
|
||||||
currentDomains := m.GetCachedDomains()
|
currentDomains := m.GetCachedDomains()
|
||||||
newDomains := m.extractDomainsFromConfig(config)
|
newDomains := m.extractDomainsFromConfig(config)
|
||||||
@@ -333,19 +331,86 @@ func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
for _, domainToRemove := range removedDomains {
|
for _, domainToRemove := range removedDomains {
|
||||||
delete(m.cache, domainToRemove)
|
delete(m.cache, domainToRemove)
|
||||||
log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString())
|
log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
for _, newDomain := range newDomains {
|
for _, newDomain := range newDomains {
|
||||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||||
log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return removedDomains, nil
|
return removedDomains, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
|
||||||
|
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) {
|
||||||
|
log.Debugf("updating cache from ServerDomains")
|
||||||
|
|
||||||
|
currentDomains := m.GetCachedDomains()
|
||||||
|
newDomains := m.extractDomainsFromServerDomains(serverDomains)
|
||||||
|
|
||||||
|
var removedDomains []domain.Domain
|
||||||
|
for _, currentDomain := range currentDomains {
|
||||||
|
found := false
|
||||||
|
for _, newDomain := range newDomains {
|
||||||
|
if currentDomain.SafeString() == newDomain.SafeString() {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
removedDomains = append(removedDomains, currentDomain)
|
||||||
|
if err := m.RemoveDomain(currentDomain); err != nil {
|
||||||
|
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newDomain := range newDomains {
|
||||||
|
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||||
|
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return removedDomains, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain {
|
||||||
|
var domains []domain.Domain
|
||||||
|
|
||||||
|
if serverDomains.Signal != "" {
|
||||||
|
domains = append(domains, serverDomains.Signal)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, relay := range serverDomains.Relay {
|
||||||
|
if relay != "" {
|
||||||
|
domains = append(domains, relay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if serverDomains.Flow != "" {
|
||||||
|
domains = append(domains, serverDomains.Flow)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stun := range serverDomains.Stuns {
|
||||||
|
if stun != "" {
|
||||||
|
domains = append(domains, stun)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, turn := range serverDomains.Turns {
|
||||||
|
if turn != "" {
|
||||||
|
domains = append(domains, turn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
|
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
|
||||||
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
|
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
@@ -354,26 +419,62 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
|
|||||||
|
|
||||||
var domains []domain.Domain
|
var domains []domain.Domain
|
||||||
|
|
||||||
if config.Signal != nil && config.Signal.Uri != "" {
|
// Extract signal domain
|
||||||
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
|
domains = append(domains, m.extractSignalDomain(config)...)
|
||||||
|
|
||||||
|
// Extract relay domains
|
||||||
|
domains = append(domains, m.extractRelayDomains(config)...)
|
||||||
|
|
||||||
|
// Extract flow domain
|
||||||
|
domains = append(domains, m.extractFlowDomain(config)...)
|
||||||
|
|
||||||
|
// Extract STUN domains
|
||||||
|
domains = append(domains, m.extractSTUNDomains(config)...)
|
||||||
|
|
||||||
|
// Extract TURN domains
|
||||||
|
domains = append(domains, m.extractTURNDomains(config)...)
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
if config.Signal == nil || config.Signal.Uri == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
|
||||||
|
return []domain.Domain{d}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
if config.Relay == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var domains []domain.Domain
|
||||||
|
for _, relayURL := range config.Relay.Urls {
|
||||||
|
if d, err := m.extractDomainFromURL(relayURL); err == nil {
|
||||||
domains = append(domains, d)
|
domains = append(domains, d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
if config.Relay != nil {
|
func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
for _, relayURL := range config.Relay.Urls {
|
if config.Flow == nil || config.Flow.Url == "" {
|
||||||
if d, err := m.extractDomainFromURL(relayURL); err == nil {
|
return nil
|
||||||
domains = append(domains, d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Flow != nil && config.Flow.Url != "" {
|
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
|
||||||
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
|
return []domain.Domain{d}
|
||||||
domains = append(domains, d)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
var domains []domain.Domain
|
||||||
for _, stun := range config.Stuns {
|
for _, stun := range config.Stuns {
|
||||||
if stun != nil && stun.Uri != "" {
|
if stun != nil && stun.Uri != "" {
|
||||||
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
|
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
|
||||||
@@ -381,7 +482,11 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
var domains []domain.Domain
|
||||||
for _, turn := range config.Turns {
|
for _, turn := range config.Turns {
|
||||||
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
|
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
|
||||||
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
|
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
|
||||||
@@ -389,7 +494,6 @@ func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []do
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,18 +528,18 @@ func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostCon
|
|||||||
|
|
||||||
stunURL, err := url.Parse(stun.Uri)
|
stunURL, err := url.Parse(stun.Uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err)
|
log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := extractDomainFromURL(stunURL)
|
d, err := extractDomainFromURL(stunURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err)
|
log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
log.Warnf("MgmtCache: failed to add STUN domain: %v", err)
|
log.Warnf("failed to add STUN domain: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -449,18 +553,18 @@ func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.Protect
|
|||||||
|
|
||||||
turnURL, err := url.Parse(turn.HostConfig.Uri)
|
turnURL, err := url.Parse(turn.HostConfig.Uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
|
log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := extractDomainFromURL(turnURL)
|
d, err := extractDomainFromURL(turnURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
|
log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
log.Warnf("MgmtCache: failed to add TURN domain: %v", err)
|
log.Warnf("failed to add TURN domain: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -5,7 +5,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -114,16 +113,15 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
|
||||||
mgmtURL, _ := url.Parse("https://api.netbird.io")
|
// Use IP address to avoid DNS resolution timeout
|
||||||
|
mgmtURL, _ := url.Parse("https://127.0.0.1")
|
||||||
|
|
||||||
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Give some time for async population
|
// IP addresses are rejected, so no domains should be cached
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
domains := resolver.GetCachedDomains()
|
domains := resolver.GetCachedDomains()
|
||||||
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
|
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
||||||
@@ -132,32 +130,33 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
|||||||
|
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
// Use IP addresses to avoid DNS resolution timeouts
|
||||||
netbirdConfig := &mgmProto.NetbirdConfig{
|
netbirdConfig := &mgmProto.NetbirdConfig{
|
||||||
Signal: &mgmProto.HostConfig{
|
Signal: &mgmProto.HostConfig{
|
||||||
Uri: "https://signal.netbird.io",
|
Uri: "https://10.0.0.1",
|
||||||
},
|
},
|
||||||
Relay: &mgmProto.RelayConfig{
|
Relay: &mgmProto.RelayConfig{
|
||||||
Urls: []string{
|
Urls: []string{
|
||||||
"https://relay1.netbird.io:443",
|
"https://10.0.0.2:443",
|
||||||
"https://relay2.netbird.io:443",
|
"https://10.0.0.3:443",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Flow: &mgmProto.FlowConfig{
|
Flow: &mgmProto.FlowConfig{
|
||||||
Url: "https://flow.netbird.io:80",
|
Url: "https://10.0.0.4:80",
|
||||||
},
|
},
|
||||||
Stuns: []*mgmProto.HostConfig{
|
Stuns: []*mgmProto.HostConfig{
|
||||||
{Uri: "stun:stun1.netbird.io:3478"},
|
{Uri: "stun:10.0.0.5:3478"},
|
||||||
{Uri: "stun:stun2.netbird.io:3478"},
|
{Uri: "stun:10.0.0.6:3478"},
|
||||||
},
|
},
|
||||||
Turns: []*mgmProto.ProtectedHostConfig{
|
Turns: []*mgmProto.ProtectedHostConfig{
|
||||||
{
|
{
|
||||||
HostConfig: &mgmProto.HostConfig{
|
HostConfig: &mgmProto.HostConfig{
|
||||||
Uri: "turn:turn1.netbird.io:3478",
|
Uri: "turn:10.0.0.7:3478",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
HostConfig: &mgmProto.HostConfig{
|
HostConfig: &mgmProto.HostConfig{
|
||||||
Uri: "turn:turn2.netbird.io:3478",
|
Uri: "turn:10.0.0.8:3478",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -166,11 +165,42 @@ func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
|||||||
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
|
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Give some time for async population
|
// IP addresses are rejected, so no domains should be cached
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
domains := resolver.GetCachedDomains()
|
domains := resolver.GetCachedDomains()
|
||||||
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
|
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
// Test with empty initial config and then add domains
|
||||||
|
initialConfig := &mgmProto.NetbirdConfig{}
|
||||||
|
|
||||||
|
// Start with empty config
|
||||||
|
removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache")
|
||||||
|
|
||||||
|
// Update to config with IP addresses instead of domains to avoid DNS resolution
|
||||||
|
// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
|
||||||
|
updatedConfig := &mgmProto.NetbirdConfig{
|
||||||
|
Signal: &mgmProto.HostConfig{
|
||||||
|
Uri: "https://127.0.0.1",
|
||||||
|
},
|
||||||
|
Flow: &mgmProto.FlowConfig{
|
||||||
|
Url: "https://192.168.1.1:80",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the method completes successfully without DNS timeouts
|
||||||
|
assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update")
|
||||||
|
|
||||||
|
// Verify no domains were actually added since IPs are rejected
|
||||||
|
domains := resolver.GetCachedDomains()
|
||||||
|
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolver_ContinueToNext(t *testing.T) {
|
func TestResolver_ContinueToNext(t *testing.T) {
|
||||||
|
@@ -5,17 +5,19 @@ import (
|
|||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockServer is the mock instance of a dns server
|
// MockServer is the mock instance of a dns server
|
||||||
type MockServer struct {
|
type MockServer struct {
|
||||||
InitializeFunc func() error
|
InitializeFunc func() error
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||||
DeregisterHandlerFunc func(domain.List, int)
|
DeregisterHandlerFunc func(domain.List, int)
|
||||||
|
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
@@ -69,3 +71,10 @@ func (m *MockServer) SearchDomains() []string {
|
|||||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||||
func (m *MockServer) ProbeAvailability() {
|
func (m *MockServer) ProbeAvailability() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||||
|
if m.UpdateServerConfigFunc != nil {
|
||||||
|
return m.UpdateServerConfigFunc(domains)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -16,6 +16,7 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
|
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
@@ -25,7 +26,6 @@ import (
|
|||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
@@ -49,6 +49,7 @@ type Server interface {
|
|||||||
OnUpdatedHostDNSServer(strings []string)
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
SearchDomains() []string
|
SearchDomains() []string
|
||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
|
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type nsGroupsByDomain struct {
|
type nsGroupsByDomain struct {
|
||||||
@@ -103,20 +104,23 @@ type handlerWrapper struct {
|
|||||||
|
|
||||||
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||||
|
|
||||||
|
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||||
|
type DefaultServerConfig struct {
|
||||||
|
Ctx context.Context
|
||||||
|
WgInterface WGIface
|
||||||
|
CustomAddress string
|
||||||
|
StatusRecorder *peer.Status
|
||||||
|
StateManager *statemanager.Manager
|
||||||
|
DisableSys bool
|
||||||
|
MgmtURL *url.URL
|
||||||
|
ServerDomains dnsconfig.ServerDomains
|
||||||
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
|
||||||
ctx context.Context,
|
|
||||||
wgInterface WGIface,
|
|
||||||
customAddress string,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
stateManager *statemanager.Manager,
|
|
||||||
disableSys bool,
|
|
||||||
mgmtURL *url.URL,
|
|
||||||
netbirdConfig *mgmProto.NetbirdConfig,
|
|
||||||
) (*DefaultServer, error) {
|
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if config.CustomAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||||
}
|
}
|
||||||
@@ -124,31 +128,23 @@ func NewDefaultServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var dnsService service
|
var dnsService service
|
||||||
if wgInterface.IsUserspaceBind() {
|
if config.WgInterface.IsUserspaceBind() {
|
||||||
dnsService = NewServiceViaMemory(wgInterface)
|
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||||
} else {
|
} else {
|
||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys)
|
server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||||
|
|
||||||
// Pre-populate management cache with management URL
|
if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
|
||||||
if mgmtURL != nil && server.mgmtCacheResolver != nil {
|
if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
|
||||||
if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil {
|
|
||||||
log.Warnf("Failed to populate management cache from management URL: %v", err)
|
log.Warnf("Failed to populate management cache from management URL: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-populate management cache with NetbirdConfig domains
|
if server.mgmtCacheResolver != nil {
|
||||||
if netbirdConfig != nil && server.mgmtCacheResolver != nil {
|
if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
|
||||||
if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil {
|
log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
|
||||||
log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register newly populated domains
|
|
||||||
domains := server.mgmtCacheResolver.GetCachedDomains()
|
|
||||||
if len(domains) > 0 {
|
|
||||||
server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,19 +216,11 @@ func newDefaultServer(
|
|||||||
mgmtCacheResolver: mgmtCacheResolver,
|
mgmtCacheResolver: mgmtCacheResolver,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register cached domains with the handler chain
|
domains := mgmtCacheResolver.GetCachedDomains()
|
||||||
registerMgmtCacheDomains := func() {
|
if len(domains) > 0 {
|
||||||
domains := mgmtCacheResolver.GetCachedDomains()
|
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
|
||||||
if len(domains) > 0 {
|
|
||||||
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register any pre-populated domains from management cache
|
|
||||||
registerMgmtCacheDomains()
|
|
||||||
|
|
||||||
// Management cache resolver will be registered for specific domains when they are added
|
|
||||||
|
|
||||||
// register with root zone, handler chain takes care of the routing
|
// register with root zone, handler chain takes care of the routing
|
||||||
dnsService.RegisterMux(".", handlerChain)
|
dnsService.RegisterMux(".", handlerChain)
|
||||||
|
|
||||||
@@ -352,7 +340,6 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
|
|
||||||
maps.Clear(s.extraDomains)
|
maps.Clear(s.extraDomains)
|
||||||
@@ -368,15 +355,6 @@ func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
|
|||||||
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
|
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration
|
|
||||||
func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error {
|
|
||||||
if s.mgmtCacheResolver == nil {
|
|
||||||
return fmt.Errorf("management cache resolver not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("populating management cache from netbird configuration")
|
|
||||||
return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
@@ -476,6 +454,29 @@ func (s *DefaultServer) ProbeAvailability() {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
if s.mgmtCacheResolver != nil {
|
||||||
|
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update management cache resolver: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(removedDomains) > 0 {
|
||||||
|
s.DeregisterHandler(removedDomains, PriorityMgmtCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
newDomains := s.mgmtCacheResolver.GetCachedDomains()
|
||||||
|
if len(newDomains) > 0 {
|
||||||
|
s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
// is the service should be Disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
|
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
@@ -363,7 +364,16 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
|
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
||||||
|
Ctx: context.Background(),
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
MgmtURL: nil,
|
||||||
|
ServerDomains: dnsconfig.ServerDomains{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -473,7 +483,16 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
|
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
||||||
|
Ctx: context.Background(),
|
||||||
|
WgInterface: wgIface,
|
||||||
|
CustomAddress: "",
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
MgmtURL: nil,
|
||||||
|
ServerDomains: dnsconfig.ServerDomains{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -575,7 +594,16 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil)
|
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
||||||
|
Ctx: context.Background(),
|
||||||
|
WgInterface: &mocWGIface{},
|
||||||
|
CustomAddress: testCase.addrPort,
|
||||||
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
|
StateManager: nil,
|
||||||
|
DisableSys: false,
|
||||||
|
MgmtURL: nil,
|
||||||
|
ServerDomains: dnsconfig.ServerDomains{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
|
@@ -8,6 +8,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -26,10 +27,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UpstreamTimeout = 15 * time.Second
|
UpstreamTimeout = 4 * time.Second
|
||||||
// ClientTimeout is the timeout for the dns.Client.
|
// ClientTimeout is the timeout for the dns.Client.
|
||||||
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
||||||
ClientTimeout = 30 * time.Second
|
ClientTimeout = 5 * time.Second
|
||||||
|
|
||||||
reactivatePeriod = 30 * time.Second
|
reactivatePeriod = 30 * time.Second
|
||||||
probeTimeout = 2 * time.Second
|
probeTimeout = 2 * time.Second
|
||||||
@@ -105,52 +106,111 @@ func (u *upstreamResolverBase) Stop() {
|
|||||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
requestID := GenerateRequestID()
|
requestID := GenerateRequestID()
|
||||||
logger := log.WithField("request_id", requestID)
|
logger := log.WithField("request_id", requestID)
|
||||||
var err error
|
|
||||||
|
|
||||||
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
|
u.prepareRequest(r)
|
||||||
|
|
||||||
|
if u.isContextDone(logger) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.tryUpstreamServers(w, r, logger) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u.writeErrorResponse(w, r, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) isContextDone(logger *log.Entry) bool {
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
logger.Tracef("%s has been stopped", u)
|
logger.Tracef("%s has been stopped", u)
|
||||||
return
|
return true
|
||||||
default:
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
|
||||||
|
timeout := u.upstreamTimeout
|
||||||
|
if len(u.upstreamServers) > 1 {
|
||||||
|
maxTotal := 5 * time.Second
|
||||||
|
minPerUpstream := 2 * time.Second
|
||||||
|
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
|
||||||
|
if scaledTimeout > minPerUpstream {
|
||||||
|
timeout = scaledTimeout
|
||||||
|
} else {
|
||||||
|
timeout = minPerUpstream
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
var rm *dns.Msg
|
if u.queryUpstream(w, r, upstream, timeout, logger) {
|
||||||
var t time.Duration
|
return true
|
||||||
|
|
||||||
func() {
|
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
|
||||||
defer cancel()
|
|
||||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
|
||||||
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if rm == nil || !rm.Response {
|
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream string, timeout time.Duration, logger *log.Entry) bool {
|
||||||
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
var rm *dns.Msg
|
||||||
continue
|
var t time.Duration
|
||||||
}
|
var err error
|
||||||
|
|
||||||
u.successCount.Add(1)
|
var startTime time.Time
|
||||||
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
func() {
|
||||||
|
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
startTime = time.Now()
|
||||||
|
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
|
||||||
|
}()
|
||||||
|
|
||||||
if err = w.WriteMsg(rm); err != nil {
|
if err != nil {
|
||||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
|
||||||
}
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if rm == nil || !rm.Response {
|
||||||
|
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||||
|
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
elapsed := time.Since(startTime)
|
||||||
|
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
|
||||||
|
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
|
||||||
|
timeoutMsg += " " + peerInfo
|
||||||
|
}
|
||||||
|
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||||
|
logger.Warnf(timeoutMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
|
u.successCount.Add(1)
|
||||||
|
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
|
||||||
|
|
||||||
|
if err := w.WriteMsg(rm); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
|
||||||
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
@@ -355,3 +415,97 @@ func GenerateRequestID() string {
|
|||||||
}
|
}
|
||||||
return hex.EncodeToString(bytes)
|
return hex.EncodeToString(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||||
|
func FormatPeerStatus(peerState *peer.State) string {
|
||||||
|
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||||
|
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
|
||||||
|
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
|
||||||
|
|
||||||
|
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case !isConnected:
|
||||||
|
statusInfo += " DISCONNECTED"
|
||||||
|
case !hasRecentHandshake:
|
||||||
|
statusInfo += " NO_RECENT_HANDSHAKE"
|
||||||
|
default:
|
||||||
|
statusInfo += " connected"
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerState.LastWireguardHandshake.IsZero() {
|
||||||
|
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
|
||||||
|
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
|
||||||
|
} else {
|
||||||
|
statusInfo += " no_handshake"
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerState.Relayed {
|
||||||
|
statusInfo += " via_relay"
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerState.Latency > 0 {
|
||||||
|
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
|
||||||
|
}
|
||||||
|
|
||||||
|
return statusInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPeerForIP finds which peer handles the given IP address
|
||||||
|
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
|
||||||
|
if statusRecorder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus := statusRecorder.GetFullStatus()
|
||||||
|
var bestMatch *peer.State
|
||||||
|
var bestPrefixLen int
|
||||||
|
|
||||||
|
for _, peerState := range fullStatus.Peers {
|
||||||
|
routes := peerState.GetRoutes()
|
||||||
|
for route := range routes {
|
||||||
|
prefix, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
|
||||||
|
peerStateCopy := peerState
|
||||||
|
bestMatch = &peerStateCopy
|
||||||
|
bestPrefixLen = prefix.Bits()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestMatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseUpstreamIP parses an upstream server address to extract the IP
|
||||||
|
func parseUpstreamIP(upstream string) (netip.Addr, error) {
|
||||||
|
upstreamIP, err := netip.ParseAddr(upstream)
|
||||||
|
if err != nil {
|
||||||
|
if host, _, err := net.SplitHostPort(upstream); err == nil {
|
||||||
|
return netip.ParseAddr(host)
|
||||||
|
}
|
||||||
|
return netip.Addr{}, err
|
||||||
|
}
|
||||||
|
return upstreamIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream string) string {
|
||||||
|
if u.statusRecorder == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamIP, err := parseUpstreamIP(upstream)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
peerInfo := findPeerForIP(upstreamIP, u.statusRecorder)
|
||||||
|
if peerInfo == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||||
|
}
|
||||||
|
@@ -33,6 +33,7 @@ import (
|
|||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -696,6 +697,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.dnsServer != nil {
|
||||||
|
serverDomains := config.ExtractFromNetbirdConfig(wCfg)
|
||||||
|
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
|
||||||
|
log.Warnf("Failed to update DNS server config: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// todo update signal
|
// todo update signal
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1604,7 +1612,19 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
|
|||||||
return dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig)
|
// Extract domains from NetBird configuration
|
||||||
|
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig)
|
||||||
|
|
||||||
|
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
|
||||||
|
Ctx: e.ctx,
|
||||||
|
WgInterface: e.wgInterface,
|
||||||
|
CustomAddress: e.config.CustomDNSAddress,
|
||||||
|
StatusRecorder: e.statusRecorder,
|
||||||
|
StateManager: e.stateManager,
|
||||||
|
DisableSys: e.config.DisableDNS,
|
||||||
|
MgmtURL: mgmtURL,
|
||||||
|
ServerDomains: serverDomains,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -2,11 +2,13 @@ package dnsinterceptor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -26,6 +28,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const dnsTimeout = 8 * time.Second
|
||||||
|
|
||||||
type domainMap map[domain.Domain][]netip.Prefix
|
type domainMap map[domain.Domain][]netip.Prefix
|
||||||
|
|
||||||
type internalDNATer interface {
|
type internalDNATer interface {
|
||||||
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
|
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
|
||||||
return
|
return
|
||||||
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
elapsed := time.Since(startTime)
|
||||||
|
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
|
||||||
|
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
|
||||||
|
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
|
||||||
|
} else {
|
||||||
|
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||||
|
}
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
logger.Errorf("failed writing DNS response: %v", err)
|
logger.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
|
||||||
|
if d.statusRecorder == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState, err := d.statusRecorder.GetPeer(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user