From 4b22ad036d5533c16727fd81c3541e0b01f9c4cd Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 11 Jul 2025 17:20:37 +0200 Subject: [PATCH] More cleanup --- client/internal/connect.go | 4 +- client/internal/dns/config/domains.go | 21 ++++++---- client/internal/dns/config/domains_test.go | 2 +- client/internal/dns/mgmt/mgmt.go | 41 +------------------ client/internal/dns/mgmt/mgmt_test.go | 11 ++++- .../{mock_server.go => mock_server_test.go} | 5 +++ client/internal/dns/server.go | 5 ++- client/internal/engine.go | 6 +-- 8 files changed, 38 insertions(+), 57 deletions(-) rename client/internal/dns/{mock_server.go => mock_server_test.go} (95%) diff --git a/client/internal/connect.go b/client/internal/connect.go index ee6158cf6..392cc6fb4 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -259,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan peerConfig := loginResp.GetPeerConfig() - engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, loginResp.GetNetbirdConfig()) + engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) if err != nil { log.Error(err) return wrapErr(err) @@ -414,7 +414,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) { } // createEngineConfig converts configuration received from Management Service to EngineConfig -func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, netbirdConfig *mgmProto.NetbirdConfig) (*EngineConfig, error) { +func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { nm := false if config.NetworkMonitor != nil { nm = *config.NetworkMonitor diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go index 27df39d1d..51bb4e998 100644 --- a/client/internal/dns/config/domains.go +++ b/client/internal/dns/config/domains.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "net" "net/netip" @@ -13,6 +14,12 @@ import ( mgmProto "github.com/netbirdio/netbird/management/proto" ) +var ( + ErrEmptyURL = errors.New("empty URL") + ErrEmptyHost = errors.New("empty host") + ErrIPNotAllowed = errors.New("IP address not allowed") +) + // ServerDomains represents the management server domains extracted from NetBird configuration type ServerDomains struct { Signal domain.Domain @@ -39,10 +46,10 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { return domains } -// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses -func extractValidDomain(rawURL string) (domain.Domain, error) { +// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses +func ExtractValidDomain(rawURL string) (domain.Domain, error) { if rawURL == "" { - return "", fmt.Errorf("empty URL") + return "", ErrEmptyURL } // Try standard URL parsing first (handles https://, http://, rels://, etc.) @@ -77,11 +84,11 @@ func extractValidDomain(rawURL string) (domain.Domain, error) { // extractDomainFromHost extracts domain from a host string, filtering out IP addresses func extractDomainFromHost(host string) (domain.Domain, error) { if host == "" { - return "", fmt.Errorf("empty host") + return "", ErrEmptyHost } if _, err := netip.ParseAddr(host); err == nil { - return "", fmt.Errorf("IP address not allowed: %s", host) + return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host) } d, err := domain.FromString(host) @@ -98,7 +105,7 @@ func extractSingleDomain(url, serviceType string) domain.Domain { return "" } - d, err := extractValidDomain(url) + d, err := ExtractValidDomain(url) if err != nil { log.Debugf("Skipping %s: %v", serviceType, err) return "" @@ -114,7 +121,7 @@ func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { if url == "" { continue } - d, err := extractValidDomain(url) + d, err := ExtractValidDomain(url) if err != nil { log.Debugf("Skipping %s: %v", serviceType, err) continue diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go index fef808ab1..c7a6c3224 100644 --- a/client/internal/dns/config/domains_test.go +++ b/client/internal/dns/config/domains_test.go @@ -87,7 +87,7 @@ func TestExtractValidDomain(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := extractValidDomain(tt.url) + result, err := ExtractValidDomain(tt.url) if tt.expectError { assert.Error(t, err, "Expected error for URL: %s", tt.url) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 78dcdae8e..f3d023c11 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -2,10 +2,8 @@ package mgmt import ( "context" - "errors" "fmt" "net" - "net/netip" "net/url" "strings" "sync" @@ -162,7 +160,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err return nil } - d, err := extractDomainFromURL(mgmtURL) + d, err := dnsconfig.ExtractValidDomain(mgmtURL.String()) if err != nil { return fmt.Errorf("extract domain from URL: %w", err) } @@ -313,40 +311,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } -// extractDomainFromURL extracts the domain from a URL. -func extractDomainFromURL(u *url.URL) (domain.Domain, error) { - if u == nil { - return "", errors.New("invalid URL") - } - - host := u.Host - // If Host is empty, try to extract from Opaque (for schemes like stun:domain:port) - if host == "" && u.Opaque != "" { - host = u.Opaque - } - if host == "" && u.Path != "" { - host = strings.TrimPrefix(u.Path, "/") - } - - if host == "" { - return "", errors.New("empty host") - } - - host, _, err := net.SplitHostPort(host) - if err != nil { - switch { - case u.Host != "": - host = u.Host - case u.Opaque != "": - host = u.Opaque - default: - host = strings.TrimPrefix(u.Path, "/") - } - } - - if _, err := netip.ParseAddr(host); err == nil { - return "", errors.New("host is an IP address, skipping") - } - - return domain.FromString(host) -} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 7e211bda3..6a27073f1 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -2,6 +2,7 @@ package mgmt import ( "context" + "fmt" "net/url" "strings" "testing" @@ -120,7 +121,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) { err := resolver.PopulateFromConfig(ctx, mgmtURL) assert.Error(t, err) - assert.Contains(t, err.Error(), "host is an IP address") + assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed) // No domains should be cached when using IP addresses domains := resolver.GetCachedDomains() @@ -285,3 +286,11 @@ func TestResolver_ManagementDomainProtection(t *testing.T) { } assert.True(t, managementStillCached, "Management domain should never be removed") } + +// extractDomainFromURL extracts a domain from a URL - test helper function +func extractDomainFromURL(u *url.URL) (domain.Domain, error) { + if u == nil { + return "", fmt.Errorf("URL is nil") + } + return dnsconfig.ExtractValidDomain(u.String()) +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server_test.go similarity index 95% rename from client/internal/dns/mock_server.go rename to client/internal/dns/mock_server_test.go index cd1cba1e6..5e20c979b 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server_test.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "net/url" "github.com/miekg/dns" @@ -78,3 +79,7 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { } return nil } + +func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { + return nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 268fd01b2..41c5b3560 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -50,6 +50,7 @@ type Server interface { SearchDomains() []string ProbeAvailability() UpdateServerConfig(domains dnsconfig.ServerDomains) error + PopulateManagementDomain(mgmtURL *url.URL) error } type nsGroupsByDomain struct { @@ -907,9 +908,9 @@ func toZone(d domain.Domain) domain.Domain { } // PopulateManagementDomain populates the DNS cache with management domain -func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error { +func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { if s.mgmtCacheResolver != nil && mgmtURL != nil { - return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL) + return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) } return nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index 08c6cb97a..c6153cde3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -673,10 +673,8 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg // Populate management URL if provided if mgmtURL != nil { - if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok { - if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil { - log.Warnf("failed to populate DNS cache with management URL: %v", err) - } + if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache with management URL: %v", err) } }