More cleanup

This commit is contained in:
Viktor Liu
2025-07-11 17:20:37 +02:00
parent 90bf1baec2
commit 4b22ad036d
8 changed files with 38 additions and 57 deletions

View File

@@ -259,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, loginResp.GetNetbirdConfig())
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -414,7 +414,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, netbirdConfig *mgmProto.NetbirdConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor

View File

@@ -1,6 +1,7 @@
package config
import (
"errors"
"fmt"
"net"
"net/netip"
@@ -13,6 +14,12 @@ import (
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var (
ErrEmptyURL = errors.New("empty URL")
ErrEmptyHost = errors.New("empty host")
ErrIPNotAllowed = errors.New("IP address not allowed")
)
// ServerDomains represents the management server domains extracted from NetBird configuration
type ServerDomains struct {
Signal domain.Domain
@@ -39,10 +46,10 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
return domains
}
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func extractValidDomain(rawURL string) (domain.Domain, error) {
// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func ExtractValidDomain(rawURL string) (domain.Domain, error) {
if rawURL == "" {
return "", fmt.Errorf("empty URL")
return "", ErrEmptyURL
}
// Try standard URL parsing first (handles https://, http://, rels://, etc.)
@@ -77,11 +84,11 @@ func extractValidDomain(rawURL string) (domain.Domain, error) {
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
func extractDomainFromHost(host string) (domain.Domain, error) {
if host == "" {
return "", fmt.Errorf("empty host")
return "", ErrEmptyHost
}
if _, err := netip.ParseAddr(host); err == nil {
return "", fmt.Errorf("IP address not allowed: %s", host)
return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host)
}
d, err := domain.FromString(host)
@@ -98,7 +105,7 @@ func extractSingleDomain(url, serviceType string) domain.Domain {
return ""
}
d, err := extractValidDomain(url)
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
return ""
@@ -114,7 +121,7 @@ func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
if url == "" {
continue
}
d, err := extractValidDomain(url)
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
continue

View File

@@ -87,7 +87,7 @@ func TestExtractValidDomain(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractValidDomain(tt.url)
result, err := ExtractValidDomain(tt.url)
if tt.expectError {
assert.Error(t, err, "Expected error for URL: %s", tt.url)

View File

@@ -2,10 +2,8 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
@@ -162,7 +160,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
return nil
}
d, err := extractDomainFromURL(mgmtURL)
d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
if err != nil {
return fmt.Errorf("extract domain from URL: %w", err)
}
@@ -313,40 +311,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// extractDomainFromURL extracts the domain from a URL.
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {
return "", errors.New("invalid URL")
}
host := u.Host
// If Host is empty, try to extract from Opaque (for schemes like stun:domain:port)
if host == "" && u.Opaque != "" {
host = u.Opaque
}
if host == "" && u.Path != "" {
host = strings.TrimPrefix(u.Path, "/")
}
if host == "" {
return "", errors.New("empty host")
}
host, _, err := net.SplitHostPort(host)
if err != nil {
switch {
case u.Host != "":
host = u.Host
case u.Opaque != "":
host = u.Opaque
default:
host = strings.TrimPrefix(u.Path, "/")
}
}
if _, err := netip.ParseAddr(host); err == nil {
return "", errors.New("host is an IP address, skipping")
}
return domain.FromString(host)
}

View File

@@ -2,6 +2,7 @@ package mgmt
import (
"context"
"fmt"
"net/url"
"strings"
"testing"
@@ -120,7 +121,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.Error(t, err)
assert.Contains(t, err.Error(), "host is an IP address")
assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
// No domains should be cached when using IP addresses
domains := resolver.GetCachedDomains()
@@ -285,3 +286,11 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
}
assert.True(t, managementStillCached, "Management domain should never be removed")
}
// extractDomainFromURL extracts a domain from a URL - test helper function
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {
return "", fmt.Errorf("URL is nil")
}
return dnsconfig.ExtractValidDomain(u.String())
}

View File

@@ -2,6 +2,7 @@ package dns
import (
"fmt"
"net/url"
"github.com/miekg/dns"
@@ -78,3 +79,7 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
}
return nil
}
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}

View File

@@ -50,6 +50,7 @@ type Server interface {
SearchDomains() []string
ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
}
type nsGroupsByDomain struct {
@@ -907,9 +908,9 @@ func toZone(d domain.Domain) domain.Domain {
}
// PopulateManagementDomain populates the DNS cache with management domain
func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error {
func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
if s.mgmtCacheResolver != nil && mgmtURL != nil {
return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL)
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
return nil
}

View File

@@ -673,10 +673,8 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
// Populate management URL if provided
if mgmtURL != nil {
if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok {
if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
}