From 90bf1baec203a66888f3291e56b731c71571d0c4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 11 Jul 2025 11:06:19 +0200 Subject: [PATCH] Cleanup --- client/internal/connect.go | 5 +- client/internal/dns/config/domains.go | 38 +- client/internal/dns/config/domains_test.go | 148 +++++++ client/internal/dns/mgmt/mgmt.go | 478 +++++---------------- client/internal/dns/mgmt/mgmt_test.go | 266 +++++++----- client/internal/dns/server.go | 46 +- client/internal/dns/server_test.go | 16 +- client/internal/engine.go | 66 +-- client/internal/engine_test.go | 6 +- client/internal/login.go | 18 +- 10 files changed, 504 insertions(+), 583 deletions(-) create mode 100644 client/internal/dns/config/domains_test.go diff --git a/client/internal/connect.go b/client/internal/connect.go index b8aad354b..ee6158cf6 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -272,11 +272,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.engine.SetNetworkMapPersistence(c.persistNetworkMap) c.engineMutex.Unlock() - if err := c.engine.Start(); err != nil { + if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } + log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) @@ -442,8 +443,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe BlockInbound: config.BlockInbound, LazyConnectionEnabled: config.LazyConnectionEnabled, - ManagementURL: config.ManagementURL, - NetbirdConfig: netbirdConfig, } if config.PreSharedKey != "" { diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go index 067608b2e..27df39d1d 100644 --- a/client/internal/dns/config/domains.go +++ b/client/internal/dns/config/domains.go @@ -5,6 +5,7 @@ import ( "net" "net/netip" "net/url" + "strings" log "github.com/sirupsen/logrus" @@ -40,19 +41,34 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { // extractValidDomain extracts a valid domain from a URL, filtering out IP addresses func extractValidDomain(rawURL string) (domain.Domain, error) { - parsedURL, err := url.Parse(rawURL) - if err != nil { - // If URL parsing fails, it might be a raw host:port, try parsing as such - if host, _, err := net.SplitHostPort(rawURL); err == nil { - return extractDomainFromHost(host) - } - // If not host:port, try as raw hostname - return extractDomainFromHost(rawURL) + if rawURL == "" { + return "", fmt.Errorf("empty URL") } - host := parsedURL.Hostname() - if host == "" { - return "", fmt.Errorf("no hostname in URL") + // Try standard URL parsing first (handles https://, http://, rels://, etc.) + if parsedURL, err := url.Parse(rawURL); err == nil && parsedURL.Hostname() != "" { + return extractDomainFromHost(parsedURL.Hostname()) + } + + // Extract domain from various formats: + // - stun:domain:port -> domain + // - turns:domain:port?params -> domain + // - domain:port -> domain + host := rawURL + + // Remove scheme prefix (stun:, turn:, turns:) + if colonIndex := strings.Index(host, ":"); colonIndex > 0 && colonIndex < 10 && !strings.Contains(host[:colonIndex], ".") { + host = host[colonIndex+1:] + } + + // Remove port suffix + if hostOnly, _, err := net.SplitHostPort(host); err == nil { + host = hostOnly + } + + // Remove query parameters + if queryIndex := strings.Index(host, "?"); queryIndex > 0 { + host = host[:queryIndex] } return extractDomainFromHost(host) diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go new file mode 100644 index 000000000..fef808ab1 --- /dev/null +++ b/client/internal/dns/config/domains_test.go @@ -0,0 +1,148 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractValidDomain(t *testing.T) { + tests := []struct { + name string + url string + expected string + expectError bool + }{ + { + name: "HTTPS URL with port", + url: "https://api.netbird.io:443", + expected: "api.netbird.io", + }, + { + name: "HTTP URL without port", + url: "http://signal.example.com", + expected: "signal.example.com", + }, + { + name: "Host with port (no scheme)", + url: "signal.netbird.io:443", + expected: "signal.netbird.io", + }, + { + name: "STUN URL", + url: "stun:stun.netbird.io:443", + expected: "stun.netbird.io", + }, + { + name: "STUN URL with different port", + url: "stun:stun.netbird.io:5555", + expected: "stun.netbird.io", + }, + { + name: "TURNS URL with query params", + url: "turns:turn.netbird.io:443?transport=tcp", + expected: "turn.netbird.io", + }, + { + name: "TURN URL", + url: "turn:turn.example.com:3478", + expected: "turn.example.com", + }, + { + name: "REL URL", + url: "rel://relay.example.com:443", + expected: "relay.example.com", + }, + { + name: "RELS URL", + url: "rels://relay.netbird.io:443", + expected: "relay.netbird.io", + }, + { + name: "Raw hostname", + url: "example.org", + expected: "example.org", + }, + { + name: "IP address should be rejected", + url: "192.168.1.1", + expectError: true, + }, + { + name: "IP address with port should be rejected", + url: "192.168.1.1:443", + expectError: true, + }, + { + name: "IPv6 address should be rejected", + url: "2001:db8::1", + expectError: true, + }, + { + name: "Empty URL", + url: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := extractValidDomain(tt.url) + + if tt.expectError { + assert.Error(t, err, "Expected error for URL: %s", tt.url) + } else { + assert.NoError(t, err, "Unexpected error for URL: %s", tt.url) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url) + } + }) + } +} + +func TestExtractDomainFromHost(t *testing.T) { + tests := []struct { + name string + host string + expected string + expectError bool + }{ + { + name: "Valid domain", + host: "example.com", + expected: "example.com", + }, + { + name: "Subdomain", + host: "api.example.com", + expected: "api.example.com", + }, + { + name: "IPv4 address", + host: "192.168.1.1", + expectError: true, + }, + { + name: "IPv6 address", + host: "2001:db8::1", + expectError: true, + }, + { + name: "Empty host", + host: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := extractDomainFromHost(tt.host) + + if tt.expectError { + assert.Error(t, err, "Expected error for host: %s", tt.host) + } else { + assert.NoError(t, err, "Unexpected error for host: %s", tt.host) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host) + } + }) + } +} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 993a51297..78dcdae8e 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -16,25 +16,19 @@ import ( dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/management/domain" - mgmProto "github.com/netbirdio/netbird/management/proto" ) -// CacheEntry holds DNS records for a cached domain -type CacheEntry struct { - ARecords []dns.RR - AAAARecords []dns.RR -} - // Resolver caches critical NetBird infrastructure domains type Resolver struct { - cache map[domain.Domain]CacheEntry - mutex sync.RWMutex + records map[dns.Question][]dns.RR + managementDomain *domain.Domain + mutex sync.RWMutex } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - cache: make(map[domain.Domain]CacheEntry), + records: make(map[dns.Question][]dns.RR), } } @@ -51,7 +45,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } question := r.Question[0] - qname := strings.ToLower(strings.TrimSuffix(question.Name, ".")) + question.Name = strings.ToLower(dns.Fqdn(question.Name)) if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { m.continueToNext(w, r) @@ -59,8 +53,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - domainKey := domain.Domain(qname) - entry, found := m.cache[domainKey] + records, found := m.records[question] m.mutex.RUnlock() if !found { @@ -73,34 +66,19 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { resp.Authoritative = false resp.RecursionAvailable = true - var records []dns.RR - if question.Qtype == dns.TypeA { - records = entry.ARecords - } else if question.Qtype == dns.TypeAAAA { - records = entry.AAAARecords - } + resp.Answer = append(resp.Answer, records...) - if len(records) == 0 { - m.continueToNext(w, r) - return - } - - for _, rr := range records { - rrCopy := dns.Copy(rr) - rrCopy.Header().Name = question.Name - resp.Answer = append(resp.Answer, rrCopy) - } - - log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString()) + log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write response: %v", err) } } -// MatchSubdomains always returns true as required by the interface. +// MatchSubdomains returns false since this resolver only handles exact domain matches +// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. func (m *Resolver) MatchSubdomains() bool { - return true + return false } // continueToNext signals the handler chain to continue to the next handler. @@ -115,8 +93,6 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { // AddDomain manually adds a domain to cache by resolving it. func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { - log.Debugf("adding domain=%s to cache", d.SafeString()) - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() @@ -130,7 +106,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { if ip.Is4() { rr := &dns.A{ Hdr: dns.RR_Header{ - Name: d.PunycodeString() + ".", + Name: strings.ToLower(dns.Fqdn(d.PunycodeString())), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, @@ -141,7 +117,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { } else if ip.Is6() { rr := &dns.AAAA{ Hdr: dns.RR_Header{ - Name: d.PunycodeString() + ".", + Name: strings.ToLower(dns.Fqdn(d.PunycodeString())), Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 300, @@ -153,10 +129,25 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { } m.mutex.Lock() - m.cache[d] = CacheEntry{ - ARecords: aRecords, - AAAARecords: aaaaRecords, + + if len(aRecords) > 0 { + aQuestion := dns.Question{ + Name: strings.ToLower(dns.Fqdn(d.PunycodeString())), + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + m.records[aQuestion] = aRecords } + + if len(aaaaRecords) > 0 { + aaaaQuestion := dns.Question{ + Name: strings.ToLower(dns.Fqdn(d.PunycodeString())), + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + m.records[aaaaQuestion] = aaaaRecords + } + m.mutex.Unlock() log.Debugf("added domain=%s with %d A records and %d AAAA records", @@ -167,12 +158,21 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { // PopulateFromConfig extracts and caches domains from the client configuration. func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { - if mgmtURL != nil { - if d, err := extractDomainFromURL(mgmtURL); err == nil { - if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("failed to add management domain: %v", err) - } - } + if mgmtURL == nil { + return nil + } + + d, err := extractDomainFromURL(mgmtURL) + if err != nil { + return fmt.Errorf("extract domain from URL: %w", err) + } + + m.mutex.Lock() + m.managementDomain = &d + m.mutex.Unlock() + + if err := m.AddDomain(ctx, d); err != nil { + return fmt.Errorf("add domain: %w", err) } return nil @@ -183,191 +183,95 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - delete(m.cache, d) + aQuestion := dns.Question{ + Name: strings.ToLower(dns.Fqdn(d.PunycodeString())), + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + delete(m.records, aQuestion) + + aaaaQuestion := dns.Question{ + Name: strings.ToLower(dns.Fqdn(d.PunycodeString())), + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + delete(m.records, aaaaQuestion) + log.Debugf("removed domain=%s from cache", d.SafeString()) return nil } -// PopulateFromNetbirdConfig extracts and caches domains from the netbird config. -func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error { - if config == nil { - return nil - } - - m.addSignalDomain(ctx, config.Signal) - m.addRelayDomains(ctx, config.Relay) - m.addFlowDomain(ctx, config.Flow) - m.addStunDomains(ctx, config.Stuns) - m.addTurnDomains(ctx, config.Turns) - - return nil -} - -// addSignalDomain adds signal server domain to cache. -func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostConfig) { - if signal == nil || signal.Uri == "" { - return - } - - signalURL, err := url.Parse(signal.Uri) - if err != nil { - // If parsing fails, it might be a raw host:port, try adding a scheme - signalURL, err = url.Parse("https://" + signal.Uri) - if err != nil { - log.Warnf("failed to parse signal URL: %v", err) - return - } - } - - d, err := extractDomainFromURL(signalURL) - if err != nil { - log.Warnf("failed to extract signal domain: %v", err) - return - } - - if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("failed to add signal domain: %v", err) - } -} - -// addRelayDomains adds relay server domains to cache. -func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayConfig) { - if relay == nil { - return - } - - for _, relayAddr := range relay.Urls { - relayURL, err := url.Parse(relayAddr) - if err != nil { - log.Warnf("failed to parse relay URL %s: %v", relayAddr, err) - continue - } - - d, err := extractDomainFromURL(relayURL) - if err != nil { - log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err) - continue - } - - if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("failed to add relay domain: %v", err) - } - } -} - -// addFlowDomain adds traffic flow server domain to cache. -func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig) { - if flow == nil || flow.Url == "" { - return - } - - flowURL, err := url.Parse(flow.Url) - if err != nil { - log.Warnf("failed to parse flow URL: %v", err) - return - } - - d, err := extractDomainFromURL(flowURL) - if err != nil { - log.Warnf("failed to extract flow domain: %v", err) - return - } - - if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("failed to add flow domain: %v", err) - } -} - // GetCachedDomains returns a list of all cached domains. -func (m *Resolver) GetCachedDomains() []domain.Domain { +func (m *Resolver) GetCachedDomains() domain.List { m.mutex.RLock() defer m.mutex.RUnlock() - domains := make([]domain.Domain, 0, len(m.cache)) - for d := range m.cache { - domains = append(domains, d) - } - return domains -} - -// ClearCache removes all cached domains and returns them for external deregistration. -func (m *Resolver) ClearCache() []domain.Domain { - m.mutex.Lock() - defer m.mutex.Unlock() - - domains := make([]domain.Domain, 0, len(m.cache)) - for d := range m.cache { - domains = append(domains, d) + domainSet := make(map[domain.Domain]struct{}) + for question := range m.records { + domainName := strings.TrimSuffix(question.Name, ".") + domainSet[domain.Domain(domainName)] = struct{}{} } - m.cache = make(map[domain.Domain]CacheEntry) - log.Debugf("cleared %d cached domains", len(domains)) + domains := make(domain.List, 0, len(domainSet)) + for d := range domainSet { + domains = append(domains, d) + } return domains } -// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations. -// Returns domains that were removed for external deregistration. -func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) { - log.Debugf("updating cache from NetbirdConfig") - +// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct +func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) { currentDomains := m.GetCachedDomains() - newDomains := m.extractDomainsFromConfig(config) + newDomains := m.extractDomainsFromServerDomains(serverDomains) - var removedDomains []domain.Domain - for _, currentDomain := range currentDomains { - found := false - for _, newDomain := range newDomains { - if currentDomain.SafeString() == newDomain.SafeString() { - found = true - break - } - } - if !found { - removedDomains = append(removedDomains, currentDomain) - } - } - - m.mutex.Lock() - for _, domainToRemove := range removedDomains { - delete(m.cache, domainToRemove) - log.Debugf("removed domain=%s from cache", domainToRemove.SafeString()) - } - m.mutex.Unlock() - - for _, newDomain := range newDomains { - if err := m.AddDomain(ctx, newDomain); err != nil { - log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) - } - } + removedDomains := m.removeStaleDomainsExceptManagement(currentDomains, newDomains) + m.addNewDomains(ctx, newDomains) return removedDomains, nil } -// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct -func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) { - log.Debugf("updating cache from ServerDomains") +// removeStaleDomainsExceptManagement removes domains not in newDomains, except management domain +func (m *Resolver) removeStaleDomainsExceptManagement(currentDomains, newDomains domain.List) domain.List { + var removedDomains domain.List - currentDomains := m.GetCachedDomains() - newDomains := m.extractDomainsFromServerDomains(serverDomains) - - var removedDomains []domain.Domain for _, currentDomain := range currentDomains { - found := false - for _, newDomain := range newDomains { - if currentDomain.SafeString() == newDomain.SafeString() { - found = true - break - } + if m.isDomainInList(currentDomain, newDomains) { + continue } - if !found { - removedDomains = append(removedDomains, currentDomain) - if err := m.RemoveDomain(currentDomain); err != nil { - log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) - } + + if m.isManagementDomain(currentDomain) { + continue + } + + removedDomains = append(removedDomains, currentDomain) + if err := m.RemoveDomain(currentDomain); err != nil { + log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) } } + return removedDomains +} + +// isDomainInList checks if domain exists in the list +func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool { + for _, d := range list { + if domain.SafeString() == d.SafeString() { + return true + } + } + return false +} + +// isManagementDomain checks if domain is the protected management domain +func (m *Resolver) isManagementDomain(domain domain.Domain) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + return m.managementDomain != nil && domain.SafeString() == m.managementDomain.SafeString() +} + +// addNewDomains adds all new domains to the cache +func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) { for _, newDomain := range newDomains { if err := m.AddDomain(ctx, newDomain); err != nil { log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) @@ -375,12 +279,10 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) } } - - return removedDomains, nil } -func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain { - var domains []domain.Domain +func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List { + var domains domain.List if serverDomains.Signal != "" { domains = append(domains, serverDomains.Signal) @@ -411,164 +313,6 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } -// extractDomainsFromConfig extracts all domains from a NetbirdConfig. -func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain { - if config == nil { - return nil - } - - var domains []domain.Domain - - // Extract signal domain - domains = append(domains, m.extractSignalDomain(config)...) - - // Extract relay domains - domains = append(domains, m.extractRelayDomains(config)...) - - // Extract flow domain - domains = append(domains, m.extractFlowDomain(config)...) - - // Extract STUN domains - domains = append(domains, m.extractSTUNDomains(config)...) - - // Extract TURN domains - domains = append(domains, m.extractTURNDomains(config)...) - - return domains -} - -func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain { - if config.Signal == nil || config.Signal.Uri == "" { - return nil - } - - if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil { - return []domain.Domain{d} - } - return nil -} - -func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { - if config.Relay == nil { - return nil - } - - var domains []domain.Domain - for _, relayURL := range config.Relay.Urls { - if d, err := m.extractDomainFromURL(relayURL); err == nil { - domains = append(domains, d) - } - } - return domains -} - -func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain { - if config.Flow == nil || config.Flow.Url == "" { - return nil - } - - if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil { - return []domain.Domain{d} - } - return nil -} - -func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain { - var domains []domain.Domain - for _, stun := range config.Stuns { - if stun != nil && stun.Uri != "" { - if d, err := m.extractDomainFromURL(stun.Uri); err == nil { - domains = append(domains, d) - } - } - } - return domains -} - -func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain { - var domains []domain.Domain - for _, turn := range config.Turns { - if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { - if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil { - domains = append(domains, d) - } - } - } - return domains -} - -// extractDomainFromSignalConfig extracts domain from signal configuration. -func (m *Resolver) extractDomainFromSignalConfig(signal *mgmProto.HostConfig) (domain.Domain, error) { - signalURL, err := url.Parse(signal.Uri) - if err != nil { - // If parsing fails, it might be a raw host:port, try adding a scheme - signalURL, err = url.Parse("https://" + signal.Uri) - if err != nil { - return "", err - } - } - return extractDomainFromURL(signalURL) -} - -// extractDomainFromURL extracts domain from a URL string. -func (m *Resolver) extractDomainFromURL(urlStr string) (domain.Domain, error) { - parsedURL, err := url.Parse(urlStr) - if err != nil { - return "", err - } - return extractDomainFromURL(parsedURL) -} - -// addStunDomains adds STUN server domains to cache. -func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostConfig) { - for _, stun := range stuns { - if stun == nil || stun.Uri == "" { - continue - } - - stunURL, err := url.Parse(stun.Uri) - if err != nil { - log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err) - continue - } - - d, err := extractDomainFromURL(stunURL) - if err != nil { - log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err) - continue - } - - if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("failed to add STUN domain: %v", err) - } - } -} - -// addTurnDomains adds TURN server domains to cache. -func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.ProtectedHostConfig) { - for _, turn := range turns { - if turn == nil || turn.HostConfig == nil || turn.HostConfig.Uri == "" { - continue - } - - turnURL, err := url.Parse(turn.HostConfig.Uri) - if err != nil { - log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err) - continue - } - - d, err := extractDomainFromURL(turnURL) - if err != nil { - log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err) - continue - } - - if err := m.AddDomain(ctx, d); err != nil { - log.Warnf("failed to add TURN domain: %v", err) - } - } -} - // extractDomainFromURL extracts the domain from a URL. func extractDomainFromURL(u *url.URL) (domain.Domain, error) { if u == nil { diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 17da1a75c..7e211bda3 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -2,22 +2,24 @@ package mgmt import ( "context" - "net" "net/url" + "strings" "testing" "github.com/miekg/dns" "github.com/stretchr/testify/assert" - mgmProto "github.com/netbirdio/netbird/management/proto" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/management/domain" ) func TestResolver_NewResolver(t *testing.T) { resolver := NewResolver() assert.NotNil(t, resolver) - assert.NotNil(t, resolver.cache) - assert.True(t, resolver.MatchSubdomains()) + assert.NotNil(t, resolver.records) + assert.False(t, resolver.MatchSubdomains()) } func TestResolver_ExtractDomainFromURL(t *testing.T) { @@ -113,145 +115,173 @@ func TestResolver_PopulateFromConfig(t *testing.T) { resolver := NewResolver() - // Use IP address to avoid DNS resolution timeout + // Test with IP address - should return error since IP addresses are rejected mgmtURL, _ := url.Parse("https://127.0.0.1") err := resolver.PopulateFromConfig(ctx, mgmtURL) - assert.NoError(t, err) + assert.Error(t, err) + assert.Contains(t, err.Error(), "host is an IP address") - // IP addresses are rejected, so no domains should be cached + // No domains should be cached when using IP addresses domains := resolver.GetCachedDomains() assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") } -func TestResolver_PopulateFromNetbirdConfig(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - +func TestResolver_ServeDNS(t *testing.T) { resolver := NewResolver() + ctx := context.Background() - // Use IP addresses to avoid DNS resolution timeouts - netbirdConfig := &mgmProto.NetbirdConfig{ - Signal: &mgmProto.HostConfig{ - Uri: "https://10.0.0.1", - }, - Relay: &mgmProto.RelayConfig{ - Urls: []string{ - "https://10.0.0.2:443", - "https://10.0.0.3:443", - }, - }, - Flow: &mgmProto.FlowConfig{ - Url: "https://10.0.0.4:80", - }, - Stuns: []*mgmProto.HostConfig{ - {Uri: "stun:10.0.0.5:3478"}, - {Uri: "stun:10.0.0.6:3478"}, - }, - Turns: []*mgmProto.ProtectedHostConfig{ - { - HostConfig: &mgmProto.HostConfig{ - Uri: "turn:10.0.0.7:3478", - }, - }, - { - HostConfig: &mgmProto.HostConfig{ - Uri: "turn:10.0.0.8:3478", - }, - }, - }, + // Add a test domain to the cache - use example.org which is reserved for testing + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) } - err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig) - assert.NoError(t, err) + // Test A record query for cached domain + t.Run("Cached domain A record", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } - // IP addresses are rejected, so no domains should be cached - domains := resolver.GetCachedDomains() - assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer") + }) + + // Test uncached domain signals to continue to next handler + t.Run("Uncached domain signals continue to next handler", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("unknown.example.com.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + // Zero flag set to true signals the handler chain to continue to next handler + assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain") + }) + + // Test that subdomains of cached domains are NOT resolved + t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query for a subdomain of our cached domain + req := new(dns.Msg) + req.SetQuestion("sub.example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains") + }) + + // Test case-insensitive matching + t.Run("Case-insensitive domain matching", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query with different casing + req := new(dns.Msg) + req.SetQuestion("EXAMPLE.ORG.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case") + }) } -func TestResolver_UpdateFromNetbirdConfig(t *testing.T) { +func TestResolver_GetCachedDomains(t *testing.T) { resolver := NewResolver() + ctx := context.Background() - // Test with empty initial config and then add domains - initialConfig := &mgmProto.NetbirdConfig{} - - // Start with empty config - removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig) - assert.NoError(t, err) - assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache") - - // Update to config with IP addresses instead of domains to avoid DNS resolution - // IP addresses will be rejected by extractDomainFromURL so no actual resolution happens - updatedConfig := &mgmProto.NetbirdConfig{ - Signal: &mgmProto.HostConfig{ - Uri: "https://127.0.0.1", - }, - Flow: &mgmProto.FlowConfig{ - Url: "https://192.168.1.1:80", - }, + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) } - removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig) - assert.NoError(t, err) + cachedDomains := resolver.GetCachedDomains() - // Verify the method completes successfully without DNS timeouts - assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update") - - // Verify no domains were actually added since IPs are rejected - domains := resolver.GetCachedDomains() - assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") + assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain") + assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original") + assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot") } -func TestResolver_ContinueToNext(t *testing.T) { +func TestResolver_ManagementDomainProtection(t *testing.T) { resolver := NewResolver() + ctx := context.Background() - // Create a mock response writer to capture the response - mockWriter := &MockResponseWriter{} + mgmtURL, _ := url.Parse("https://example.org") + err := resolver.PopulateFromConfig(ctx, mgmtURL) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } - // Create a test DNS query - req := new(dns.Msg) - req.SetQuestion("unknown.example.com.", dns.TypeA) + initialDomains := resolver.GetCachedDomains() + if len(initialDomains) == 0 { + t.Skip("Management domain failed to resolve, skipping test") + } + assert.Equal(t, 1, len(initialDomains), "Should have management domain cached") + assert.Equal(t, "example.org", initialDomains[0].SafeString()) - // Call continueToNext - resolver.continueToNext(mockWriter, req) + serverDomains := dnsconfig.ServerDomains{ + Signal: "google.com", + Relay: []domain.Domain{"cloudflare.com"}, + } - // Verify the response - assert.NotNil(t, mockWriter.msg) - assert.Equal(t, dns.RcodeNameError, mockWriter.msg.Rcode) - assert.True(t, mockWriter.msg.MsgHdr.Zero) + _, err = resolver.UpdateFromServerDomains(ctx, serverDomains) + if err != nil { + t.Logf("Server domains update failed: %v", err) + } + + finalDomains := resolver.GetCachedDomains() + + managementStillCached := false + for _, d := range finalDomains { + if d.SafeString() == "example.org" { + managementStillCached = true + break + } + } + assert.True(t, managementStillCached, "Management domain should never be removed") } - -// MockResponseWriter is a simple mock implementation of dns.ResponseWriter for testing -type MockResponseWriter struct { - msg *dns.Msg -} - -func (m *MockResponseWriter) LocalAddr() net.Addr { - return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53} -} - -func (m *MockResponseWriter) RemoteAddr() net.Addr { - return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345} -} - -func (m *MockResponseWriter) WriteMsg(msg *dns.Msg) error { - m.msg = msg - return nil -} - -func (m *MockResponseWriter) Write([]byte) (int, error) { - return 0, nil -} - -func (m *MockResponseWriter) Close() error { - return nil -} - -func (m *MockResponseWriter) TsigStatus() error { - return nil -} - -func (m *MockResponseWriter) TsigTimersOnly(bool) {} - -func (m *MockResponseWriter) Hijack() {} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8d7ebfcac..268fd01b2 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -74,7 +74,6 @@ type DefaultServer struct { handlerChain *HandlerChain extraDomains map[domain.Domain]int - // management cache resolver for critical infrastructure domains mgmtCacheResolver *mgmt.Resolver // permanent related properties @@ -106,18 +105,15 @@ type registeredHandlerMap map[types.HandlerID]handlerWrapper // DefaultServerConfig holds configuration parameters for NewDefaultServer type DefaultServerConfig struct { - Ctx context.Context WgInterface WGIface CustomAddress string StatusRecorder *peer.Status StateManager *statemanager.Manager DisableSys bool - MgmtURL *url.URL - ServerDomains dnsconfig.ServerDomains } // NewDefaultServer returns a new dns server -func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) { var addrPort *netip.AddrPort if config.CustomAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) @@ -134,20 +130,7 @@ func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) { dnsService = newServiceViaListener(config.WgInterface, addrPort) } - server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) - - if config.MgmtURL != nil && server.mgmtCacheResolver != nil { - if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil { - log.Warnf("Failed to populate management cache from management URL: %v", err) - } - } - - if server.mgmtCacheResolver != nil { - if err := server.UpdateServerConfig(config.ServerDomains); err != nil { - log.Warnf("Failed to populate management cache from ServerDomains: %v", err) - } - } - + server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) return server, nil } @@ -197,7 +180,6 @@ func newDefaultServer( handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) - // Create management cache resolver mgmtCacheResolver := mgmt.NewResolver() defaultServer := &DefaultServer{ @@ -345,20 +327,8 @@ func (s *DefaultServer) Stop() { maps.Clear(s.extraDomains) } -// PopulateMgmtCacheFromConfig populates the management cache with domains from the client configuration -func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error { - if s.mgmtCacheResolver == nil { - return fmt.Errorf("management cache resolver not initialized") - } - - log.Debug("populating management cache from client configuration") - return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) -} - - // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone - func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { s.hostsDNSHolder.set(hostsDnsList) @@ -465,12 +435,12 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro } if len(removedDomains) > 0 { - s.DeregisterHandler(removedDomains, PriorityMgmtCache) + s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) } newDomains := s.mgmtCacheResolver.GetCachedDomains() if len(newDomains) > 0 { - s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache) + s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) } } @@ -935,3 +905,11 @@ func toZone(d domain.Domain) domain.Domain { ), ) } + +// PopulateManagementDomain populates the DNS cache with management domain +func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error { + if s.mgmtCacheResolver != nil && mgmtURL != nil { + return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL) + } + return nil +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 4892fb3af..905fff68e 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -23,7 +23,6 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" - dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/types" @@ -364,15 +363,12 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(DefaultServerConfig{ - Ctx: context.Background(), + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ WgInterface: wgIface, CustomAddress: "", StatusRecorder: peer.NewRecorder("mgm"), StateManager: nil, DisableSys: false, - MgmtURL: nil, - ServerDomains: dnsconfig.ServerDomains{}, }) if err != nil { t.Fatal(err) @@ -483,15 +479,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(DefaultServerConfig{ - Ctx: context.Background(), + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ WgInterface: wgIface, CustomAddress: "", StatusRecorder: peer.NewRecorder("mgm"), StateManager: nil, DisableSys: false, - MgmtURL: nil, - ServerDomains: dnsconfig.ServerDomains{}, }) if err != nil { t.Errorf("create DNS server: %v", err) @@ -594,15 +587,12 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(DefaultServerConfig{ - Ctx: context.Background(), + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ WgInterface: &mocWGIface{}, CustomAddress: testCase.addrPort, StatusRecorder: peer.NewRecorder("mgm"), StateManager: nil, DisableSys: false, - MgmtURL: nil, - ServerDomains: dnsconfig.ServerDomains{}, }) if err != nil { t.Fatalf("%v", err) diff --git a/client/internal/engine.go b/client/internal/engine.go index f4aed0b3f..08c6cb97a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -33,7 +33,7 @@ import ( nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dns/config" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" @@ -126,12 +126,6 @@ type EngineConfig struct { BlockInbound bool LazyConnectionEnabled bool - - // ManagementURL is the URL of the management server for DNS cache - ManagementURL *url.URL - - // NetbirdConfig contains signal, relay, and flow server configuration - NetbirdConfig *mgmProto.NetbirdConfig } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -350,7 +344,7 @@ func (e *Engine) Stop() error { // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. // However, they will be established once an event with a list of peers to connect to will be received from Management Service -func (e *Engine) Start() error { +func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -395,13 +389,18 @@ func (e *Engine) Start() error { return fmt.Errorf("read initial settings: %w", err) } - dnsServer, err := e.newDnsServer(dnsConfig, e.config.ManagementURL, e.config.NetbirdConfig) + dnsServer, err := e.newDnsServer(dnsConfig) if err != nil { e.close() return fmt.Errorf("create dns server: %w", err) } e.dnsServer = dnsServer + // Populate DNS cache with NetbirdConfig and management URL for early resolution + if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache: %v", err) + } + e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ Context: e.ctx, PublicKey: e.config.WgPrivateKey.PublicKey().String(), @@ -666,6 +665,32 @@ func (e *Engine) removePeer(peerKey string) error { return nil } +// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response +func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { + if e.dnsServer == nil { + return nil + } + + // Populate management URL if provided + if mgmtURL != nil { + if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok { + if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache with management URL: %v", err) + } + } + } + + // Populate NetbirdConfig domains if provided + if netbirdConfig != nil { + serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig) + if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { + return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err) + } + } + + return nil +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -697,11 +722,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("handle the flow configuration: %w", err) } - if e.dnsServer != nil { - serverDomains := config.ExtractFromNetbirdConfig(wCfg) - if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { - log.Warnf("Failed to update DNS server config: %v", err) - } + if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil { + log.Warnf("Failed to update DNS server config: %v", err) } // todo update signal @@ -1587,7 +1609,7 @@ func (e *Engine) wgInterfaceCreate() (err error) { return err } -func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbirdConfig *mgmProto.NetbirdConfig) (dns.Server, error) { +func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { // due to tests where we are using a mocked version of the DNS server if e.dnsServer != nil { return e.dnsServer, nil @@ -1612,18 +1634,13 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird return dnsServer, nil default: - // Extract domains from NetBird configuration - serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig) - - dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{ - Ctx: e.ctx, + + dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ WgInterface: e.wgInterface, CustomAddress: e.config.CustomDNSAddress, StatusRecorder: e.statusRecorder, StateManager: e.stateManager, DisableSys: e.config.DisableDNS, - MgmtURL: mgmtURL, - ServerDomains: serverDomains, }) if err != nil { return nil, err @@ -1643,11 +1660,6 @@ func (e *Engine) GetFirewallManager() firewallManager.Manager { return e.firewall } -// GetDNSServer returns the DNS server -func (e *Engine) GetDNSServer() dns.Server { - return e.dnsServer -} - func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f4ed8f1c0..6b317f244 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -261,7 +261,7 @@ func TestEngine_SSH(t *testing.T) { }, }, nil } - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) } @@ -605,7 +605,7 @@ func TestEngine_Sync(t *testing.T) { } }() - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) return @@ -1060,7 +1060,7 @@ func TestEngine_MultiplePeers(t *testing.T) { defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) wg.Done() diff --git a/client/internal/login.go b/client/internal/login.go index 53fa17d90..b25962814 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -39,7 +39,7 @@ func IsLoginRequired(ctx context.Context, config *Config) (bool, error) { return false, err } - _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) + _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) if isLoginNeeded(err) { return true, nil } @@ -68,14 +68,18 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string return err } - serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) + serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) if serverKey != nil && isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) + if err != nil { + return err + } + } else if err != nil { return err } - return err + return nil } func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { @@ -100,11 +104,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm return mgmClient, err } -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, error) { +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { serverKey, err := mgmClient.GetServerPublicKey() if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err + return nil, nil, err } sysInfo := system.GetInfo(ctx) @@ -120,8 +124,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockInbound, config.LazyConnectionEnabled, ) - _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, err + loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) + return serverKey, loginResp, err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.