More cleanup

This commit is contained in:
Viktor Liu
2025-07-11 17:20:37 +02:00
parent 90bf1baec2
commit 4b22ad036d
8 changed files with 38 additions and 57 deletions

View File

@@ -259,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, loginResp.GetNetbirdConfig()) engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
@@ -414,7 +414,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
} }
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, netbirdConfig *mgmProto.NetbirdConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false nm := false
if config.NetworkMonitor != nil { if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor nm = *config.NetworkMonitor

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -13,6 +14,12 @@ import (
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var (
ErrEmptyURL = errors.New("empty URL")
ErrEmptyHost = errors.New("empty host")
ErrIPNotAllowed = errors.New("IP address not allowed")
)
// ServerDomains represents the management server domains extracted from NetBird configuration // ServerDomains represents the management server domains extracted from NetBird configuration
type ServerDomains struct { type ServerDomains struct {
Signal domain.Domain Signal domain.Domain
@@ -39,10 +46,10 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
return domains return domains
} }
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses // ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func extractValidDomain(rawURL string) (domain.Domain, error) { func ExtractValidDomain(rawURL string) (domain.Domain, error) {
if rawURL == "" { if rawURL == "" {
return "", fmt.Errorf("empty URL") return "", ErrEmptyURL
} }
// Try standard URL parsing first (handles https://, http://, rels://, etc.) // Try standard URL parsing first (handles https://, http://, rels://, etc.)
@@ -77,11 +84,11 @@ func extractValidDomain(rawURL string) (domain.Domain, error) {
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses // extractDomainFromHost extracts domain from a host string, filtering out IP addresses
func extractDomainFromHost(host string) (domain.Domain, error) { func extractDomainFromHost(host string) (domain.Domain, error) {
if host == "" { if host == "" {
return "", fmt.Errorf("empty host") return "", ErrEmptyHost
} }
if _, err := netip.ParseAddr(host); err == nil { if _, err := netip.ParseAddr(host); err == nil {
return "", fmt.Errorf("IP address not allowed: %s", host) return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host)
} }
d, err := domain.FromString(host) d, err := domain.FromString(host)
@@ -98,7 +105,7 @@ func extractSingleDomain(url, serviceType string) domain.Domain {
return "" return ""
} }
d, err := extractValidDomain(url) d, err := ExtractValidDomain(url)
if err != nil { if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err) log.Debugf("Skipping %s: %v", serviceType, err)
return "" return ""
@@ -114,7 +121,7 @@ func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
if url == "" { if url == "" {
continue continue
} }
d, err := extractValidDomain(url) d, err := ExtractValidDomain(url)
if err != nil { if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err) log.Debugf("Skipping %s: %v", serviceType, err)
continue continue

View File

@@ -87,7 +87,7 @@ func TestExtractValidDomain(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, err := extractValidDomain(tt.url) result, err := ExtractValidDomain(tt.url)
if tt.expectError { if tt.expectError {
assert.Error(t, err, "Expected error for URL: %s", tt.url) assert.Error(t, err, "Expected error for URL: %s", tt.url)

View File

@@ -2,10 +2,8 @@ package mgmt
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
@@ -162,7 +160,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
return nil return nil
} }
d, err := extractDomainFromURL(mgmtURL) d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
if err != nil { if err != nil {
return fmt.Errorf("extract domain from URL: %w", err) return fmt.Errorf("extract domain from URL: %w", err)
} }
@@ -313,40 +311,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains return domains
} }
// extractDomainFromURL extracts the domain from a URL.
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {
return "", errors.New("invalid URL")
}
host := u.Host
// If Host is empty, try to extract from Opaque (for schemes like stun:domain:port)
if host == "" && u.Opaque != "" {
host = u.Opaque
}
if host == "" && u.Path != "" {
host = strings.TrimPrefix(u.Path, "/")
}
if host == "" {
return "", errors.New("empty host")
}
host, _, err := net.SplitHostPort(host)
if err != nil {
switch {
case u.Host != "":
host = u.Host
case u.Opaque != "":
host = u.Opaque
default:
host = strings.TrimPrefix(u.Path, "/")
}
}
if _, err := netip.ParseAddr(host); err == nil {
return "", errors.New("host is an IP address, skipping")
}
return domain.FromString(host)
}

View File

@@ -2,6 +2,7 @@ package mgmt
import ( import (
"context" "context"
"fmt"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
@@ -120,7 +121,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
err := resolver.PopulateFromConfig(ctx, mgmtURL) err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "host is an IP address") assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
// No domains should be cached when using IP addresses // No domains should be cached when using IP addresses
domains := resolver.GetCachedDomains() domains := resolver.GetCachedDomains()
@@ -285,3 +286,11 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
} }
assert.True(t, managementStillCached, "Management domain should never be removed") assert.True(t, managementStillCached, "Management domain should never be removed")
} }
// extractDomainFromURL extracts a domain from a URL - test helper function
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {
return "", fmt.Errorf("URL is nil")
}
return dnsconfig.ExtractValidDomain(u.String())
}

View File

@@ -2,6 +2,7 @@ package dns
import ( import (
"fmt" "fmt"
"net/url"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -78,3 +79,7 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
} }
return nil return nil
} }
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}

View File

@@ -50,6 +50,7 @@ type Server interface {
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
} }
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
@@ -907,9 +908,9 @@ func toZone(d domain.Domain) domain.Domain {
} }
// PopulateManagementDomain populates the DNS cache with management domain // PopulateManagementDomain populates the DNS cache with management domain
func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error { func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
if s.mgmtCacheResolver != nil && mgmtURL != nil { if s.mgmtCacheResolver != nil && mgmtURL != nil {
return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL) return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
} }
return nil return nil
} }

View File

@@ -673,10 +673,8 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
// Populate management URL if provided // Populate management URL if provided
if mgmtURL != nil { if mgmtURL != nil {
if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok { if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil {
if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil { log.Warnf("failed to populate DNS cache with management URL: %v", err)
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
} }
} }