diff --git a/client/internal/connect.go b/client/internal/connect.go index 7b49fa3ad..b8aad354b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -259,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan peerConfig := loginResp.GetPeerConfig() - engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) + engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, loginResp.GetNetbirdConfig()) if err != nil { log.Error(err) return wrapErr(err) @@ -413,7 +413,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) { } // createEngineConfig converts configuration received from Management Service to EngineConfig -func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { +func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, netbirdConfig *mgmProto.NetbirdConfig) (*EngineConfig, error) { nm := false if config.NetworkMonitor != nil { nm = *config.NetworkMonitor @@ -442,6 +442,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe BlockInbound: config.BlockInbound, LazyConnectionEnabled: config.LazyConnectionEnabled, + ManagementURL: config.ManagementURL, + NetbirdConfig: netbirdConfig, } if config.PreSharedKey != "" { diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 7e7e7cc2d..ee80113e1 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,10 +11,11 @@ import ( ) const ( - PriorityLocal = 100 - PriorityDNSRoute = 75 - PriorityUpstream = 50 - PriorityDefault = 1 + PriorityMgmtCache = 150 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 ) type SubdomainMatcher interface { diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index c19356a7e..c714331b8 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool { // String returns a string representation of the local resolver func (d *Resolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.records)) + return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) } func (d *Resolver) Stop() {} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go new file mode 100644 index 000000000..ba8b1640b --- /dev/null +++ b/client/internal/dns/mgmt/mgmt.go @@ -0,0 +1,504 @@ +package mgmt + +import ( + "context" + "errors" + "net" + "net/netip" + "net/url" + "strings" + "sync" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/domain" + mgmProto "github.com/netbirdio/netbird/management/proto" +) + +// CacheEntry holds DNS records for a cached domain +type CacheEntry struct { + ARecords []dns.RR + AAAARecords []dns.RR +} + +// Resolver caches critical NetBird infrastructure domains +type Resolver struct { + cache map[domain.Domain]CacheEntry + mutex sync.RWMutex + systemResolver *net.Resolver +} + +// NewResolver creates a new management domains cache resolver. +func NewResolver() *Resolver { + return &Resolver{ + cache: make(map[domain.Domain]CacheEntry), + systemResolver: net.DefaultResolver, + } +} + +// String returns a string representation of the resolver. +func (m *Resolver) String() string { + return "MgmtCacheResolver" +} + +// ServeDNS implements dns.Handler interface. +func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + m.continueToNext(w, r) + return + } + + question := r.Question[0] + qname := strings.ToLower(strings.TrimSuffix(question.Name, ".")) + + if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { + m.continueToNext(w, r) + return + } + + log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype]) + + m.mutex.RLock() + parsedDomain, err := domain.FromString(qname) + if err != nil { + log.Tracef("MgmtCache: invalid domain format: %s", qname) + m.mutex.RUnlock() + m.continueToNext(w, r) + return + } + + entry, found := m.cache[parsedDomain] + m.mutex.RUnlock() + + if !found { + log.Tracef("MgmtCache: no cache entry found for domain=%s", qname) + m.continueToNext(w, r) + return + } + + resp := &dns.Msg{} + resp.SetReply(r) + resp.Authoritative = false + resp.RecursionAvailable = true + + var records []dns.RR + if question.Qtype == dns.TypeA { + records = entry.ARecords + } else if question.Qtype == dns.TypeAAAA { + records = entry.AAAARecords + } + + if len(records) == 0 { + log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString()) + m.continueToNext(w, r) + return + } + + for _, rr := range records { + rrCopy := dns.Copy(rr) + rrCopy.Header().Name = question.Name + resp.Answer = append(resp.Answer, rrCopy) + } + + log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString()) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("MgmtCache: failed to write response: %v", err) + } +} + +// MatchSubdomains always returns true as required by the interface. +func (m *Resolver) MatchSubdomains() bool { + return true +} + +// continueToNext signals the handler chain to continue to the next handler. +func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("MgmtCache: failed to write continue signal: %v", err) + } +} + +// AddDomain manually adds a domain to cache by resolving it. +func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { + log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString()) + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + var aRecords, aaaaRecords []dns.RR + + if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil { + for _, ip := range ips { + if ip.Is4() { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: d.PunycodeString() + ".", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: ip.AsSlice(), + } + aRecords = append(aRecords, rr) + } else if ip.Is6() { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: d.PunycodeString() + ".", + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: ip.AsSlice(), + } + aaaaRecords = append(aaaaRecords, rr) + } + } + + m.mutex.Lock() + m.cache[d] = CacheEntry{ + ARecords: aRecords, + AAAARecords: aaaaRecords, + } + m.mutex.Unlock() + + log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records", + d.SafeString(), len(aRecords), len(aaaaRecords)) + } else { + log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err) + return err + } + + return nil +} + +// PopulateFromConfig extracts and caches domains from the client configuration. +func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { + if mgmtURL != nil { + if d, err := extractDomainFromURL(mgmtURL); err == nil { + if err := m.AddDomain(ctx, d); err != nil { + log.Warnf("MgmtCache: failed to add management domain: %v", err) + } + } + } + + return nil +} + +// PopulateFromNetbirdConfig extracts and caches domains from the netbird config. +func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error { + if config == nil { + return nil + } + + m.addSignalDomain(ctx, config.Signal) + m.addRelayDomains(ctx, config.Relay) + m.addFlowDomain(ctx, config.Flow) + m.addStunDomains(ctx, config.Stuns) + m.addTurnDomains(ctx, config.Turns) + + return nil +} + +// addSignalDomain adds signal server domain to cache. +func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostConfig) { + if signal == nil || signal.Uri == "" { + return + } + + signalURL, err := url.Parse(signal.Uri) + if err != nil { + // If parsing fails, it might be a raw host:port, try adding a scheme + signalURL, err = url.Parse("https://" + signal.Uri) + if err != nil { + log.Warnf("MgmtCache: failed to parse signal URL: %v", err) + return + } + } + + d, err := extractDomainFromURL(signalURL) + if err != nil { + log.Warnf("MgmtCache: failed to extract signal domain: %v", err) + return + } + + if err := m.AddDomain(ctx, d); err != nil { + log.Warnf("MgmtCache: failed to add signal domain: %v", err) + } +} + +// addRelayDomains adds relay server domains to cache. +func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayConfig) { + if relay == nil { + return + } + + for _, relayAddr := range relay.Urls { + relayURL, err := url.Parse(relayAddr) + if err != nil { + log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err) + continue + } + + d, err := extractDomainFromURL(relayURL) + if err != nil { + log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err) + continue + } + + if err := m.AddDomain(ctx, d); err != nil { + log.Warnf("MgmtCache: failed to add relay domain: %v", err) + } + } +} + +// addFlowDomain adds traffic flow server domain to cache. +func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig) { + if flow == nil || flow.Url == "" { + return + } + + flowURL, err := url.Parse(flow.Url) + if err != nil { + log.Warnf("MgmtCache: failed to parse flow URL: %v", err) + return + } + + d, err := extractDomainFromURL(flowURL) + if err != nil { + log.Warnf("MgmtCache: failed to extract flow domain: %v", err) + return + } + + if err := m.AddDomain(ctx, d); err != nil { + log.Warnf("MgmtCache: failed to add flow domain: %v", err) + } +} + +// GetCachedDomains returns a list of all cached domains. +func (m *Resolver) GetCachedDomains() []domain.Domain { + m.mutex.RLock() + defer m.mutex.RUnlock() + + domains := make([]domain.Domain, 0, len(m.cache)) + for d := range m.cache { + domains = append(domains, d) + } + return domains +} + +// ClearCache removes all cached domains and returns them for external deregistration. +func (m *Resolver) ClearCache() []domain.Domain { + m.mutex.Lock() + defer m.mutex.Unlock() + + domains := make([]domain.Domain, 0, len(m.cache)) + for d := range m.cache { + domains = append(domains, d) + } + + m.cache = make(map[domain.Domain]CacheEntry) + log.Debugf("MgmtCache: cleared %d cached domains", len(domains)) + + return domains +} + +// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations. +// Returns domains that were removed for external deregistration. +func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) { + log.Debugf("MgmtCache: updating cache from NetbirdConfig") + + currentDomains := m.GetCachedDomains() + newDomains := m.extractDomainsFromConfig(config) + + var removedDomains []domain.Domain + for _, currentDomain := range currentDomains { + found := false + for _, newDomain := range newDomains { + if currentDomain.SafeString() == newDomain.SafeString() { + found = true + break + } + } + if !found { + removedDomains = append(removedDomains, currentDomain) + } + } + + m.mutex.Lock() + for _, domainToRemove := range removedDomains { + delete(m.cache, domainToRemove) + log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString()) + } + m.mutex.Unlock() + + for _, newDomain := range newDomains { + if err := m.AddDomain(ctx, newDomain); err != nil { + log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err) + } + } + + return removedDomains, nil +} + +// extractDomainsFromConfig extracts all domains from a NetbirdConfig. +func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain { + if config == nil { + return nil + } + + var domains []domain.Domain + + if config.Signal != nil && config.Signal.Uri != "" { + if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil { + domains = append(domains, d) + } + } + + if config.Relay != nil { + for _, relayURL := range config.Relay.Urls { + if d, err := m.extractDomainFromURL(relayURL); err == nil { + domains = append(domains, d) + } + } + } + + if config.Flow != nil && config.Flow.Url != "" { + if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil { + domains = append(domains, d) + } + } + + for _, stun := range config.Stuns { + if stun != nil && stun.Uri != "" { + if d, err := m.extractDomainFromURL(stun.Uri); err == nil { + domains = append(domains, d) + } + } + } + + for _, turn := range config.Turns { + if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { + if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil { + domains = append(domains, d) + } + } + } + + return domains +} + +// extractDomainFromSignalConfig extracts domain from signal configuration. +func (m *Resolver) extractDomainFromSignalConfig(signal *mgmProto.HostConfig) (domain.Domain, error) { + signalURL, err := url.Parse(signal.Uri) + if err != nil { + // If parsing fails, it might be a raw host:port, try adding a scheme + signalURL, err = url.Parse("https://" + signal.Uri) + if err != nil { + return "", err + } + } + return extractDomainFromURL(signalURL) +} + +// extractDomainFromURL extracts domain from a URL string. +func (m *Resolver) extractDomainFromURL(urlStr string) (domain.Domain, error) { + parsedURL, err := url.Parse(urlStr) + if err != nil { + return "", err + } + return extractDomainFromURL(parsedURL) +} + +// addStunDomains adds STUN server domains to cache. +func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostConfig) { + for _, stun := range stuns { + if stun == nil || stun.Uri == "" { + continue + } + + stunURL, err := url.Parse(stun.Uri) + if err != nil { + log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err) + continue + } + + d, err := extractDomainFromURL(stunURL) + if err != nil { + log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err) + continue + } + + if err := m.AddDomain(ctx, d); err != nil { + log.Warnf("MgmtCache: failed to add STUN domain: %v", err) + } + } +} + +// addTurnDomains adds TURN server domains to cache. +func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.ProtectedHostConfig) { + for _, turn := range turns { + if turn == nil || turn.HostConfig == nil || turn.HostConfig.Uri == "" { + continue + } + + turnURL, err := url.Parse(turn.HostConfig.Uri) + if err != nil { + log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err) + continue + } + + d, err := extractDomainFromURL(turnURL) + if err != nil { + log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err) + continue + } + + if err := m.AddDomain(ctx, d); err != nil { + log.Warnf("MgmtCache: failed to add TURN domain: %v", err) + } + } +} + +// extractDomainFromURL extracts the domain from a URL. +func extractDomainFromURL(u *url.URL) (domain.Domain, error) { + if u == nil { + return "", errors.New("invalid URL") + } + + host := u.Host + // If Host is empty, try to extract from Opaque (for schemes like stun:domain:port) + if host == "" && u.Opaque != "" { + host = u.Opaque + } + if host == "" && u.Path != "" { + host = strings.TrimPrefix(u.Path, "/") + } + + if host == "" { + return "", errors.New("empty host") + } + + host, _, err := net.SplitHostPort(host) + if err != nil { + switch { + case u.Host != "": + host = u.Host + case u.Opaque != "": + host = u.Opaque + default: + host = strings.TrimPrefix(u.Path, "/") + } + } + + if _, err := netip.ParseAddr(host); err == nil { + return "", errors.New("host is an IP address, skipping") + } + + return domain.FromString(host) +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go new file mode 100644 index 000000000..bd4aec99b --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -0,0 +1,227 @@ +package mgmt + +import ( + "context" + "net" + "net/url" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + + mgmProto "github.com/netbirdio/netbird/management/proto" +) + +func TestResolver_NewResolver(t *testing.T) { + resolver := NewResolver() + + assert.NotNil(t, resolver) + assert.NotNil(t, resolver.cache) + assert.True(t, resolver.MatchSubdomains()) +} + +func TestResolver_ExtractDomainFromURL(t *testing.T) { + tests := []struct { + name string + urlStr string + expectedDom string + expectError bool + }{ + { + name: "HTTPS URL with port", + urlStr: "https://api.netbird.io:443", + expectedDom: "api.netbird.io", + expectError: false, + }, + { + name: "HTTP URL without port", + urlStr: "http://signal.example.com", + expectedDom: "signal.example.com", + expectError: false, + }, + { + name: "URL with path", + urlStr: "https://relay.netbird.io/status", + expectedDom: "relay.netbird.io", + expectError: false, + }, + { + name: "Invalid URL", + urlStr: "not-a-valid-url", + expectedDom: "not-a-valid-url", + expectError: false, + }, + { + name: "Empty URL", + urlStr: "", + expectedDom: "", + expectError: true, + }, + { + name: "STUN URL", + urlStr: "stun:stun.example.com:3478", + expectedDom: "stun.example.com", + expectError: false, + }, + { + name: "TURN URL", + urlStr: "turn:turn.example.com:3478", + expectedDom: "turn.example.com", + expectError: false, + }, + { + name: "REL URL", + urlStr: "rel://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + { + name: "RELS URL", + urlStr: "rels://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var parsedURL *url.URL + var err error + + if tt.urlStr != "" { + parsedURL, err = url.Parse(tt.urlStr) + if err != nil && !tt.expectError { + t.Fatalf("Failed to parse URL: %v", err) + } + } + + domain, err := extractDomainFromURL(parsedURL) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedDom, domain.SafeString()) + } + }) + } +} + +func TestResolver_PopulateFromConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := NewResolver() + + mgmtURL, _ := url.Parse("https://api.netbird.io") + + err := resolver.PopulateFromConfig(ctx, mgmtURL) + assert.NoError(t, err) + + // Give some time for async population + time.Sleep(100 * time.Millisecond) + + domains := resolver.GetCachedDomains() + assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature +} + +func TestResolver_PopulateFromNetbirdConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := NewResolver() + + netbirdConfig := &mgmProto.NetbirdConfig{ + Signal: &mgmProto.HostConfig{ + Uri: "https://signal.netbird.io", + }, + Relay: &mgmProto.RelayConfig{ + Urls: []string{ + "https://relay1.netbird.io:443", + "https://relay2.netbird.io:443", + }, + }, + Flow: &mgmProto.FlowConfig{ + Url: "https://flow.netbird.io:80", + }, + Stuns: []*mgmProto.HostConfig{ + {Uri: "stun:stun1.netbird.io:3478"}, + {Uri: "stun:stun2.netbird.io:3478"}, + }, + Turns: []*mgmProto.ProtectedHostConfig{ + { + HostConfig: &mgmProto.HostConfig{ + Uri: "turn:turn1.netbird.io:3478", + }, + }, + { + HostConfig: &mgmProto.HostConfig{ + Uri: "turn:turn2.netbird.io:3478", + }, + }, + }, + } + + err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig) + assert.NoError(t, err) + + // Give some time for async population + time.Sleep(100 * time.Millisecond) + + domains := resolver.GetCachedDomains() + assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature +} + +func TestResolver_ContinueToNext(t *testing.T) { + resolver := NewResolver() + + // Create a mock response writer to capture the response + mockWriter := &MockResponseWriter{} + + // Create a test DNS query + req := new(dns.Msg) + req.SetQuestion("unknown.example.com.", dns.TypeA) + + // Call continueToNext + resolver.continueToNext(mockWriter, req) + + // Verify the response + assert.NotNil(t, mockWriter.msg) + assert.Equal(t, dns.RcodeNameError, mockWriter.msg.Rcode) + assert.True(t, mockWriter.msg.MsgHdr.Zero) +} + +// MockResponseWriter is a simple mock implementation of dns.ResponseWriter for testing +type MockResponseWriter struct { + msg *dns.Msg +} + +func (m *MockResponseWriter) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53} +} + +func (m *MockResponseWriter) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345} +} + +func (m *MockResponseWriter) WriteMsg(msg *dns.Msg) error { + m.msg = msg + return nil +} + +func (m *MockResponseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func (m *MockResponseWriter) Close() error { + return nil +} + +func (m *MockResponseWriter) TsigStatus() error { + return nil +} + +func (m *MockResponseWriter) TsigTimersOnly(bool) {} + +func (m *MockResponseWriter) Hijack() {} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index e81aebf98..1fa76c8ef 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/netip" + "net/url" "runtime" "strings" "sync" @@ -16,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -23,6 +25,7 @@ import ( cProto "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" + mgmProto "github.com/netbirdio/netbird/management/proto" ) // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes @@ -70,6 +73,9 @@ type DefaultServer struct { handlerChain *HandlerChain extraDomains map[domain.Domain]int + // management cache resolver for critical infrastructure domains + mgmtCacheResolver *mgmt.Resolver + // permanent related properties permanent bool hostsDNSHolder *hostsDNSHolder @@ -105,6 +111,8 @@ func NewDefaultServer( statusRecorder *peer.Status, stateManager *statemanager.Manager, disableSys bool, + mgmtURL *url.URL, + netbirdConfig *mgmProto.NetbirdConfig, ) (*DefaultServer, error) { var addrPort *netip.AddrPort if customAddress != "" { @@ -122,7 +130,29 @@ func NewDefaultServer( dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil + server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys) + + // Pre-populate management cache with management URL + if mgmtURL != nil && server.mgmtCacheResolver != nil { + if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil { + log.Warnf("Failed to populate management cache from management URL: %v", err) + } + } + + // Pre-populate management cache with NetbirdConfig domains + if netbirdConfig != nil && server.mgmtCacheResolver != nil { + if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil { + log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err) + } + + // Register newly populated domains + domains := server.mgmtCacheResolver.GetCachedDomains() + if len(domains) > 0 { + server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache) + } + } + + return server, nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems @@ -170,21 +200,39 @@ func newDefaultServer( ) *DefaultServer { handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) + + // Create management cache resolver + mgmtCacheResolver := mgmt.NewResolver() + defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: local.NewResolver(), - wgInterface: wgInterface, - statusRecorder: statusRecorder, - stateManager: stateManager, - hostsDNSHolder: newHostsDNSHolder(), + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), + wgInterface: wgInterface, + statusRecorder: statusRecorder, + stateManager: stateManager, + hostsDNSHolder: newHostsDNSHolder(), + mgmtCacheResolver: mgmtCacheResolver, } + // Register cached domains with the handler chain + registerMgmtCacheDomains := func() { + domains := mgmtCacheResolver.GetCachedDomains() + if len(domains) > 0 { + defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache) + } + } + + // Register any pre-populated domains from management cache + registerMgmtCacheDomains() + + // Management cache resolver will be registered for specific domains when they are added + // register with root zone, handler chain takes care of the routing dnsService.RegisterMux(".", handlerChain) @@ -208,7 +256,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { - log.Debugf("registering handler %s with priority %d", handler, priority) + log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains) for _, domain := range domains { if domain == "" { @@ -236,7 +284,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { } func (s *DefaultServer) deregisterHandler(domains []string, priority int) { - log.Debugf("deregistering handler %v with priority %d", domains, priority) + log.Debugf("deregistering handler with priority %d for %v", priority, domains) for _, domain := range domains { if domain == "" { @@ -304,11 +352,32 @@ func (s *DefaultServer) Stop() { } } + s.service.Stop() maps.Clear(s.extraDomains) } +// PopulateMgmtCacheFromConfig populates the management cache with domains from the client configuration +func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error { + if s.mgmtCacheResolver == nil { + return fmt.Errorf("management cache resolver not initialized") + } + + log.Debug("populating management cache from client configuration") + return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) +} + +// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration +func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error { + if s.mgmtCacheResolver == nil { + return fmt.Errorf("management cache resolver not initialized") + } + + log.Debug("populating management cache from netbird configuration") + return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config) +} + // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 21a9e2f2d..4a806be3f 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -363,7 +363,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil) if err != nil { t.Fatal(err) } @@ -473,7 +473,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -575,7 +575,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c7c609ebe..04d25790d 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -75,7 +75,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("upstream %v", u.upstreamServers) + return fmt.Sprintf("Upstream %v", u.upstreamServers) } // ID returns the unique handler ID diff --git a/client/internal/engine.go b/client/internal/engine.go index e9772b359..9104f042f 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "net/netip" + "net/url" "reflect" "runtime" "slices" @@ -124,6 +125,12 @@ type EngineConfig struct { BlockInbound bool LazyConnectionEnabled bool + + // ManagementURL is the URL of the management server for DNS cache + ManagementURL *url.URL + + // NetbirdConfig contains signal, relay, and flow server configuration + NetbirdConfig *mgmProto.NetbirdConfig } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -387,7 +394,7 @@ func (e *Engine) Start() error { return fmt.Errorf("read initial settings: %w", err) } - dnsServer, err := e.newDnsServer(dnsConfig) + dnsServer, err := e.newDnsServer(dnsConfig, e.config.ManagementURL, e.config.NetbirdConfig) if err != nil { e.close() return fmt.Errorf("create dns server: %w", err) @@ -1572,7 +1579,7 @@ func (e *Engine) wgInterfaceCreate() (err error) { return err } -func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { +func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbirdConfig *mgmProto.NetbirdConfig) (dns.Server, error) { // due to tests where we are using a mocked version of the DNS server if e.dnsServer != nil { return e.dnsServer, nil @@ -1597,7 +1604,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig) if err != nil { return nil, err } @@ -1616,6 +1623,11 @@ func (e *Engine) GetFirewallManager() firewallManager.Manager { return e.firewall } +// GetDNSServer returns the DNS server +func (e *Engine) GetDNSServer() dns.Server { + return e.dnsServer +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil {