mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 11:00:06 +02:00
More cleanup
This commit is contained in:
@@ -259,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, loginResp.GetNetbirdConfig())
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -414,7 +414,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, netbirdConfig *mgmProto.NetbirdConfig) (*EngineConfig, error) {
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -13,6 +14,12 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyURL = errors.New("empty URL")
|
||||
ErrEmptyHost = errors.New("empty host")
|
||||
ErrIPNotAllowed = errors.New("IP address not allowed")
|
||||
)
|
||||
|
||||
// ServerDomains represents the management server domains extracted from NetBird configuration
|
||||
type ServerDomains struct {
|
||||
Signal domain.Domain
|
||||
@@ -39,10 +46,10 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
|
||||
return domains
|
||||
}
|
||||
|
||||
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
|
||||
func extractValidDomain(rawURL string) (domain.Domain, error) {
|
||||
// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses
|
||||
func ExtractValidDomain(rawURL string) (domain.Domain, error) {
|
||||
if rawURL == "" {
|
||||
return "", fmt.Errorf("empty URL")
|
||||
return "", ErrEmptyURL
|
||||
}
|
||||
|
||||
// Try standard URL parsing first (handles https://, http://, rels://, etc.)
|
||||
@@ -77,11 +84,11 @@ func extractValidDomain(rawURL string) (domain.Domain, error) {
|
||||
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
|
||||
func extractDomainFromHost(host string) (domain.Domain, error) {
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("empty host")
|
||||
return "", ErrEmptyHost
|
||||
}
|
||||
|
||||
if _, err := netip.ParseAddr(host); err == nil {
|
||||
return "", fmt.Errorf("IP address not allowed: %s", host)
|
||||
return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host)
|
||||
}
|
||||
|
||||
d, err := domain.FromString(host)
|
||||
@@ -98,7 +105,7 @@ func extractSingleDomain(url, serviceType string) domain.Domain {
|
||||
return ""
|
||||
}
|
||||
|
||||
d, err := extractValidDomain(url)
|
||||
d, err := ExtractValidDomain(url)
|
||||
if err != nil {
|
||||
log.Debugf("Skipping %s: %v", serviceType, err)
|
||||
return ""
|
||||
@@ -114,7 +121,7 @@ func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
|
||||
if url == "" {
|
||||
continue
|
||||
}
|
||||
d, err := extractValidDomain(url)
|
||||
d, err := ExtractValidDomain(url)
|
||||
if err != nil {
|
||||
log.Debugf("Skipping %s: %v", serviceType, err)
|
||||
continue
|
||||
|
@@ -87,7 +87,7 @@ func TestExtractValidDomain(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := extractValidDomain(tt.url)
|
||||
result, err := ExtractValidDomain(tt.url)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, "Expected error for URL: %s", tt.url)
|
||||
|
@@ -2,10 +2,8 @@ package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -162,7 +160,7 @@ func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) err
|
||||
return nil
|
||||
}
|
||||
|
||||
d, err := extractDomainFromURL(mgmtURL)
|
||||
d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("extract domain from URL: %w", err)
|
||||
}
|
||||
@@ -313,40 +311,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
return domains
|
||||
}
|
||||
|
||||
// extractDomainFromURL extracts the domain from a URL.
|
||||
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
||||
if u == nil {
|
||||
return "", errors.New("invalid URL")
|
||||
}
|
||||
|
||||
host := u.Host
|
||||
// If Host is empty, try to extract from Opaque (for schemes like stun:domain:port)
|
||||
if host == "" && u.Opaque != "" {
|
||||
host = u.Opaque
|
||||
}
|
||||
if host == "" && u.Path != "" {
|
||||
host = strings.TrimPrefix(u.Path, "/")
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return "", errors.New("empty host")
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
switch {
|
||||
case u.Host != "":
|
||||
host = u.Host
|
||||
case u.Opaque != "":
|
||||
host = u.Opaque
|
||||
default:
|
||||
host = strings.TrimPrefix(u.Path, "/")
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := netip.ParseAddr(host); err == nil {
|
||||
return "", errors.New("host is an IP address, skipping")
|
||||
}
|
||||
|
||||
return domain.FromString(host)
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -120,7 +121,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
||||
|
||||
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "host is an IP address")
|
||||
assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
|
||||
|
||||
// No domains should be cached when using IP addresses
|
||||
domains := resolver.GetCachedDomains()
|
||||
@@ -285,3 +286,11 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
|
||||
}
|
||||
assert.True(t, managementStillCached, "Management domain should never be removed")
|
||||
}
|
||||
|
||||
// extractDomainFromURL extracts a domain from a URL - test helper function
|
||||
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
||||
if u == nil {
|
||||
return "", fmt.Errorf("URL is nil")
|
||||
}
|
||||
return dnsconfig.ExtractValidDomain(u.String())
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
@@ -78,3 +79,7 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
@@ -50,6 +50,7 @@ type Server interface {
|
||||
SearchDomains() []string
|
||||
ProbeAvailability()
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -907,9 +908,9 @@ func toZone(d domain.Domain) domain.Domain {
|
||||
}
|
||||
|
||||
// PopulateManagementDomain populates the DNS cache with management domain
|
||||
func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error {
|
||||
func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
if s.mgmtCacheResolver != nil && mgmtURL != nil {
|
||||
return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL)
|
||||
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@@ -673,10 +673,8 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
||||
|
||||
// Populate management URL if provided
|
||||
if mgmtURL != nil {
|
||||
if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok {
|
||||
if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil {
|
||||
log.Warnf("failed to populate DNS cache with management URL: %v", err)
|
||||
}
|
||||
if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil {
|
||||
log.Warnf("failed to populate DNS cache with management URL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user