This commit is contained in:
Viktor Liu
2025-07-11 11:06:19 +02:00
parent 629757c911
commit 90bf1baec2
10 changed files with 504 additions and 583 deletions

View File

@@ -272,11 +272,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine.SetNetworkMapPersistence(c.persistNetworkMap) c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil { if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
@@ -442,8 +443,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
BlockInbound: config.BlockInbound, BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled, LazyConnectionEnabled: config.LazyConnectionEnabled,
ManagementURL: config.ManagementURL,
NetbirdConfig: netbirdConfig,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {

View File

@@ -5,6 +5,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"net/url" "net/url"
"strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -40,19 +41,34 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses // extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func extractValidDomain(rawURL string) (domain.Domain, error) { func extractValidDomain(rawURL string) (domain.Domain, error) {
parsedURL, err := url.Parse(rawURL) if rawURL == "" {
if err != nil { return "", fmt.Errorf("empty URL")
// If URL parsing fails, it might be a raw host:port, try parsing as such
if host, _, err := net.SplitHostPort(rawURL); err == nil {
return extractDomainFromHost(host)
}
// If not host:port, try as raw hostname
return extractDomainFromHost(rawURL)
} }
host := parsedURL.Hostname() // Try standard URL parsing first (handles https://, http://, rels://, etc.)
if host == "" { if parsedURL, err := url.Parse(rawURL); err == nil && parsedURL.Hostname() != "" {
return "", fmt.Errorf("no hostname in URL") return extractDomainFromHost(parsedURL.Hostname())
}
// Extract domain from various formats:
// - stun:domain:port -> domain
// - turns:domain:port?params -> domain
// - domain:port -> domain
host := rawURL
// Remove scheme prefix (stun:, turn:, turns:)
if colonIndex := strings.Index(host, ":"); colonIndex > 0 && colonIndex < 10 && !strings.Contains(host[:colonIndex], ".") {
host = host[colonIndex+1:]
}
// Remove port suffix
if hostOnly, _, err := net.SplitHostPort(host); err == nil {
host = hostOnly
}
// Remove query parameters
if queryIndex := strings.Index(host, "?"); queryIndex > 0 {
host = host[:queryIndex]
} }
return extractDomainFromHost(host) return extractDomainFromHost(host)

View File

@@ -0,0 +1,148 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtractValidDomain(t *testing.T) {
tests := []struct {
name string
url string
expected string
expectError bool
}{
{
name: "HTTPS URL with port",
url: "https://api.netbird.io:443",
expected: "api.netbird.io",
},
{
name: "HTTP URL without port",
url: "http://signal.example.com",
expected: "signal.example.com",
},
{
name: "Host with port (no scheme)",
url: "signal.netbird.io:443",
expected: "signal.netbird.io",
},
{
name: "STUN URL",
url: "stun:stun.netbird.io:443",
expected: "stun.netbird.io",
},
{
name: "STUN URL with different port",
url: "stun:stun.netbird.io:5555",
expected: "stun.netbird.io",
},
{
name: "TURNS URL with query params",
url: "turns:turn.netbird.io:443?transport=tcp",
expected: "turn.netbird.io",
},
{
name: "TURN URL",
url: "turn:turn.example.com:3478",
expected: "turn.example.com",
},
{
name: "REL URL",
url: "rel://relay.example.com:443",
expected: "relay.example.com",
},
{
name: "RELS URL",
url: "rels://relay.netbird.io:443",
expected: "relay.netbird.io",
},
{
name: "Raw hostname",
url: "example.org",
expected: "example.org",
},
{
name: "IP address should be rejected",
url: "192.168.1.1",
expectError: true,
},
{
name: "IP address with port should be rejected",
url: "192.168.1.1:443",
expectError: true,
},
{
name: "IPv6 address should be rejected",
url: "2001:db8::1",
expectError: true,
},
{
name: "Empty URL",
url: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractValidDomain(tt.url)
if tt.expectError {
assert.Error(t, err, "Expected error for URL: %s", tt.url)
} else {
assert.NoError(t, err, "Unexpected error for URL: %s", tt.url)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url)
}
})
}
}
func TestExtractDomainFromHost(t *testing.T) {
tests := []struct {
name string
host string
expected string
expectError bool
}{
{
name: "Valid domain",
host: "example.com",
expected: "example.com",
},
{
name: "Subdomain",
host: "api.example.com",
expected: "api.example.com",
},
{
name: "IPv4 address",
host: "192.168.1.1",
expectError: true,
},
{
name: "IPv6 address",
host: "2001:db8::1",
expectError: true,
},
{
name: "Empty host",
host: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractDomainFromHost(tt.host)
if tt.expectError {
assert.Error(t, err, "Expected error for host: %s", tt.host)
} else {
assert.NoError(t, err, "Unexpected error for host: %s", tt.host)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host)
}
})
}
}

View File

@@ -16,25 +16,19 @@ import (
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
) )
// CacheEntry holds DNS records for a cached domain
type CacheEntry struct {
ARecords []dns.RR
AAAARecords []dns.RR
}
// Resolver caches critical NetBird infrastructure domains // Resolver caches critical NetBird infrastructure domains
type Resolver struct { type Resolver struct {
cache map[domain.Domain]CacheEntry records map[dns.Question][]dns.RR
mutex sync.RWMutex managementDomain *domain.Domain
mutex sync.RWMutex
} }
// NewResolver creates a new management domains cache resolver. // NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
cache: make(map[domain.Domain]CacheEntry), records: make(map[dns.Question][]dns.RR),
} }
} }
@@ -51,7 +45,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
question := r.Question[0] question := r.Question[0]
qname := strings.ToLower(strings.TrimSuffix(question.Name, ".")) question.Name = strings.ToLower(dns.Fqdn(question.Name))
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
m.continueToNext(w, r) m.continueToNext(w, r)
@@ -59,8 +53,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
m.mutex.RLock() m.mutex.RLock()
domainKey := domain.Domain(qname) records, found := m.records[question]
entry, found := m.cache[domainKey]
m.mutex.RUnlock() m.mutex.RUnlock()
if !found { if !found {
@@ -73,34 +66,19 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp.Authoritative = false resp.Authoritative = false
resp.RecursionAvailable = true resp.RecursionAvailable = true
var records []dns.RR resp.Answer = append(resp.Answer, records...)
if question.Qtype == dns.TypeA {
records = entry.ARecords
} else if question.Qtype == dns.TypeAAAA {
records = entry.AAAARecords
}
if len(records) == 0 { log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
m.continueToNext(w, r)
return
}
for _, rr := range records {
rrCopy := dns.Copy(rr)
rrCopy.Header().Name = question.Name
resp.Answer = append(resp.Answer, rrCopy)
}
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write response: %v", err) log.Errorf("failed to write response: %v", err)
} }
} }
// MatchSubdomains always returns true as required by the interface. // MatchSubdomains returns false since this resolver only handles exact domain matches
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
func (m *Resolver) MatchSubdomains() bool { func (m *Resolver) MatchSubdomains() bool {
return true return false
} }
// continueToNext signals the handler chain to continue to the next handler. // continueToNext signals the handler chain to continue to the next handler.
@@ -115,8 +93,6 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain manually adds a domain to cache by resolving it. // AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("adding domain=%s to cache", d.SafeString())
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel() defer cancel()
@@ -130,7 +106,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
if ip.Is4() { if ip.Is4() {
rr := &dns.A{ rr := &dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".", Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Rrtype: dns.TypeA, Rrtype: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 300, Ttl: 300,
@@ -141,7 +117,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
} else if ip.Is6() { } else if ip.Is6() {
rr := &dns.AAAA{ rr := &dns.AAAA{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".", Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Rrtype: dns.TypeAAAA, Rrtype: dns.TypeAAAA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 300, Ttl: 300,
@@ -153,10 +129,25 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
} }
m.mutex.Lock() m.mutex.Lock()
m.cache[d] = CacheEntry{
ARecords: aRecords, if len(aRecords) > 0 {
AAAARecords: aaaaRecords, aQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
} }
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock() m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records", log.Debugf("added domain=%s with %d A records and %d AAAA records",
@@ -167,12 +158,21 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
// PopulateFromConfig extracts and caches domains from the client configuration. // PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL != nil { if mgmtURL == nil {
if d, err := extractDomainFromURL(mgmtURL); err == nil { return nil
if err := m.AddDomain(ctx, d); err != nil { }
log.Warnf("failed to add management domain: %v", err)
} d, err := extractDomainFromURL(mgmtURL)
} if err != nil {
return fmt.Errorf("extract domain from URL: %w", err)
}
m.mutex.Lock()
m.managementDomain = &d
m.mutex.Unlock()
if err := m.AddDomain(ctx, d); err != nil {
return fmt.Errorf("add domain: %w", err)
} }
return nil return nil
@@ -183,191 +183,95 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
delete(m.cache, d) aQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString()) log.Debugf("removed domain=%s from cache", d.SafeString())
return nil return nil
} }
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
if config == nil {
return nil
}
m.addSignalDomain(ctx, config.Signal)
m.addRelayDomains(ctx, config.Relay)
m.addFlowDomain(ctx, config.Flow)
m.addStunDomains(ctx, config.Stuns)
m.addTurnDomains(ctx, config.Turns)
return nil
}
// addSignalDomain adds signal server domain to cache.
func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostConfig) {
if signal == nil || signal.Uri == "" {
return
}
signalURL, err := url.Parse(signal.Uri)
if err != nil {
// If parsing fails, it might be a raw host:port, try adding a scheme
signalURL, err = url.Parse("https://" + signal.Uri)
if err != nil {
log.Warnf("failed to parse signal URL: %v", err)
return
}
}
d, err := extractDomainFromURL(signalURL)
if err != nil {
log.Warnf("failed to extract signal domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add signal domain: %v", err)
}
}
// addRelayDomains adds relay server domains to cache.
func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayConfig) {
if relay == nil {
return
}
for _, relayAddr := range relay.Urls {
relayURL, err := url.Parse(relayAddr)
if err != nil {
log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
continue
}
d, err := extractDomainFromURL(relayURL)
if err != nil {
log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add relay domain: %v", err)
}
}
}
// addFlowDomain adds traffic flow server domain to cache.
func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig) {
if flow == nil || flow.Url == "" {
return
}
flowURL, err := url.Parse(flow.Url)
if err != nil {
log.Warnf("failed to parse flow URL: %v", err)
return
}
d, err := extractDomainFromURL(flowURL)
if err != nil {
log.Warnf("failed to extract flow domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add flow domain: %v", err)
}
}
// GetCachedDomains returns a list of all cached domains. // GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() []domain.Domain { func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
domains := make([]domain.Domain, 0, len(m.cache)) domainSet := make(map[domain.Domain]struct{})
for d := range m.cache { for question := range m.records {
domains = append(domains, d) domainName := strings.TrimSuffix(question.Name, ".")
} domainSet[domain.Domain(domainName)] = struct{}{}
return domains
}
// ClearCache removes all cached domains and returns them for external deregistration.
func (m *Resolver) ClearCache() []domain.Domain {
m.mutex.Lock()
defer m.mutex.Unlock()
domains := make([]domain.Domain, 0, len(m.cache))
for d := range m.cache {
domains = append(domains, d)
} }
m.cache = make(map[domain.Domain]CacheEntry) domains := make(domain.List, 0, len(domainSet))
log.Debugf("cleared %d cached domains", len(domains)) for d := range domainSet {
domains = append(domains, d)
}
return domains return domains
} }
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations. // UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
// Returns domains that were removed for external deregistration. func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
log.Debugf("updating cache from NetbirdConfig")
currentDomains := m.GetCachedDomains() currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromConfig(config) newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains []domain.Domain removedDomains := m.removeStaleDomainsExceptManagement(currentDomains, newDomains)
for _, currentDomain := range currentDomains { m.addNewDomains(ctx, newDomains)
found := false
for _, newDomain := range newDomains {
if currentDomain.SafeString() == newDomain.SafeString() {
found = true
break
}
}
if !found {
removedDomains = append(removedDomains, currentDomain)
}
}
m.mutex.Lock()
for _, domainToRemove := range removedDomains {
delete(m.cache, domainToRemove)
log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
}
m.mutex.Unlock()
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
}
}
return removedDomains, nil return removedDomains, nil
} }
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct // removeStaleDomainsExceptManagement removes domains not in newDomains, except management domain
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) { func (m *Resolver) removeStaleDomainsExceptManagement(currentDomains, newDomains domain.List) domain.List {
log.Debugf("updating cache from ServerDomains") var removedDomains domain.List
currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains []domain.Domain
for _, currentDomain := range currentDomains { for _, currentDomain := range currentDomains {
found := false if m.isDomainInList(currentDomain, newDomains) {
for _, newDomain := range newDomains { continue
if currentDomain.SafeString() == newDomain.SafeString() {
found = true
break
}
} }
if !found {
removedDomains = append(removedDomains, currentDomain) if m.isManagementDomain(currentDomain) {
if err := m.RemoveDomain(currentDomain); err != nil { continue
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) }
}
removedDomains = append(removedDomains, currentDomain)
if err := m.RemoveDomain(currentDomain); err != nil {
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
} }
} }
return removedDomains
}
// isDomainInList checks if domain exists in the list
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
for _, d := range list {
if domain.SafeString() == d.SafeString() {
return true
}
}
return false
}
// isManagementDomain checks if domain is the protected management domain
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.managementDomain != nil && domain.SafeString() == m.managementDomain.SafeString()
}
// addNewDomains adds all new domains to the cache
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains { for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil { if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
@@ -375,12 +279,10 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
} }
} }
return removedDomains, nil
} }
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain { func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains []domain.Domain var domains domain.List
if serverDomains.Signal != "" { if serverDomains.Signal != "" {
domains = append(domains, serverDomains.Signal) domains = append(domains, serverDomains.Signal)
@@ -411,164 +313,6 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains return domains
} }
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
if config == nil {
return nil
}
var domains []domain.Domain
// Extract signal domain
domains = append(domains, m.extractSignalDomain(config)...)
// Extract relay domains
domains = append(domains, m.extractRelayDomains(config)...)
// Extract flow domain
domains = append(domains, m.extractFlowDomain(config)...)
// Extract STUN domains
domains = append(domains, m.extractSTUNDomains(config)...)
// Extract TURN domains
domains = append(domains, m.extractTURNDomains(config)...)
return domains
}
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Signal == nil || config.Signal.Uri == "" {
return nil
}
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay == nil {
return nil
}
var domains []domain.Domain
for _, relayURL := range config.Relay.Urls {
if d, err := m.extractDomainFromURL(relayURL); err == nil {
domains = append(domains, d)
}
}
return domains
}
func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Flow == nil || config.Flow.Url == "" {
return nil
}
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
domains = append(domains, d)
}
}
}
return domains
}
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
domains = append(domains, d)
}
}
}
return domains
}
// extractDomainFromSignalConfig extracts domain from signal configuration.
func (m *Resolver) extractDomainFromSignalConfig(signal *mgmProto.HostConfig) (domain.Domain, error) {
signalURL, err := url.Parse(signal.Uri)
if err != nil {
// If parsing fails, it might be a raw host:port, try adding a scheme
signalURL, err = url.Parse("https://" + signal.Uri)
if err != nil {
return "", err
}
}
return extractDomainFromURL(signalURL)
}
// extractDomainFromURL extracts domain from a URL string.
func (m *Resolver) extractDomainFromURL(urlStr string) (domain.Domain, error) {
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", err
}
return extractDomainFromURL(parsedURL)
}
// addStunDomains adds STUN server domains to cache.
func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostConfig) {
for _, stun := range stuns {
if stun == nil || stun.Uri == "" {
continue
}
stunURL, err := url.Parse(stun.Uri)
if err != nil {
log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
continue
}
d, err := extractDomainFromURL(stunURL)
if err != nil {
log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add STUN domain: %v", err)
}
}
}
// addTurnDomains adds TURN server domains to cache.
func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.ProtectedHostConfig) {
for _, turn := range turns {
if turn == nil || turn.HostConfig == nil || turn.HostConfig.Uri == "" {
continue
}
turnURL, err := url.Parse(turn.HostConfig.Uri)
if err != nil {
log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
continue
}
d, err := extractDomainFromURL(turnURL)
if err != nil {
log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add TURN domain: %v", err)
}
}
}
// extractDomainFromURL extracts the domain from a URL. // extractDomainFromURL extracts the domain from a URL.
func extractDomainFromURL(u *url.URL) (domain.Domain, error) { func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil { if u == nil {

View File

@@ -2,22 +2,24 @@ package mgmt
import ( import (
"context" "context"
"net"
"net/url" "net/url"
"strings"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
mgmProto "github.com/netbirdio/netbird/management/proto" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/management/domain"
) )
func TestResolver_NewResolver(t *testing.T) { func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
assert.NotNil(t, resolver) assert.NotNil(t, resolver)
assert.NotNil(t, resolver.cache) assert.NotNil(t, resolver.records)
assert.True(t, resolver.MatchSubdomains()) assert.False(t, resolver.MatchSubdomains())
} }
func TestResolver_ExtractDomainFromURL(t *testing.T) { func TestResolver_ExtractDomainFromURL(t *testing.T) {
@@ -113,145 +115,173 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
// Use IP address to avoid DNS resolution timeout // Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1") mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL) err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.NoError(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "host is an IP address")
// IP addresses are rejected, so no domains should be cached // No domains should be cached when using IP addresses
domains := resolver.GetCachedDomains() domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
} }
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) { func TestResolver_ServeDNS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := NewResolver() resolver := NewResolver()
ctx := context.Background()
// Use IP addresses to avoid DNS resolution timeouts // Add a test domain to the cache - use example.org which is reserved for testing
netbirdConfig := &mgmProto.NetbirdConfig{ testDomain, err := domain.FromString("example.org")
Signal: &mgmProto.HostConfig{ if err != nil {
Uri: "https://10.0.0.1", t.Fatalf("Failed to create domain: %v", err)
}, }
Relay: &mgmProto.RelayConfig{ err = resolver.AddDomain(ctx, testDomain)
Urls: []string{ if err != nil {
"https://10.0.0.2:443", t.Skipf("Skipping test due to DNS resolution failure: %v", err)
"https://10.0.0.3:443",
},
},
Flow: &mgmProto.FlowConfig{
Url: "https://10.0.0.4:80",
},
Stuns: []*mgmProto.HostConfig{
{Uri: "stun:10.0.0.5:3478"},
{Uri: "stun:10.0.0.6:3478"},
},
Turns: []*mgmProto.ProtectedHostConfig{
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:10.0.0.7:3478",
},
},
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:10.0.0.8:3478",
},
},
},
} }
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig) // Test A record query for cached domain
assert.NoError(t, err) t.Run("Cached domain A record", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// IP addresses are rejected, so no domains should be cached req := new(dns.Msg)
domains := resolver.GetCachedDomains() req.SetQuestion("example.org.", dns.TypeA)
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
})
// Test uncached domain signals to continue to next handler
t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
req := new(dns.Msg)
req.SetQuestion("unknown.example.com.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
// Zero flag set to true signals the handler chain to continue to next handler
assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
})
// Test that subdomains of cached domains are NOT resolved
t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query for a subdomain of our cached domain
req := new(dns.Msg)
req.SetQuestion("sub.example.org.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
})
// Test case-insensitive matching
t.Run("Case-insensitive domain matching", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query with different casing
req := new(dns.Msg)
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
})
} }
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) { func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
ctx := context.Background()
// Test with empty initial config and then add domains testDomain, err := domain.FromString("example.org")
initialConfig := &mgmProto.NetbirdConfig{} if err != nil {
t.Fatalf("Failed to create domain: %v", err)
// Start with empty config }
removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig) err = resolver.AddDomain(ctx, testDomain)
assert.NoError(t, err) if err != nil {
assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache") t.Skipf("Skipping test due to DNS resolution failure: %v", err)
// Update to config with IP addresses instead of domains to avoid DNS resolution
// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
updatedConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{
Uri: "https://127.0.0.1",
},
Flow: &mgmProto.FlowConfig{
Url: "https://192.168.1.1:80",
},
} }
removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig) cachedDomains := resolver.GetCachedDomains()
assert.NoError(t, err)
// Verify the method completes successfully without DNS timeouts assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update") assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
// Verify no domains were actually added since IPs are rejected
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
} }
func TestResolver_ContinueToNext(t *testing.T) { func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
ctx := context.Background()
// Create a mock response writer to capture the response mgmtURL, _ := url.Parse("https://example.org")
mockWriter := &MockResponseWriter{} err := resolver.PopulateFromConfig(ctx, mgmtURL)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Create a test DNS query initialDomains := resolver.GetCachedDomains()
req := new(dns.Msg) if len(initialDomains) == 0 {
req.SetQuestion("unknown.example.com.", dns.TypeA) t.Skip("Management domain failed to resolve, skipping test")
}
assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
assert.Equal(t, "example.org", initialDomains[0].SafeString())
// Call continueToNext serverDomains := dnsconfig.ServerDomains{
resolver.continueToNext(mockWriter, req) Signal: "google.com",
Relay: []domain.Domain{"cloudflare.com"},
}
// Verify the response _, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
assert.NotNil(t, mockWriter.msg) if err != nil {
assert.Equal(t, dns.RcodeNameError, mockWriter.msg.Rcode) t.Logf("Server domains update failed: %v", err)
assert.True(t, mockWriter.msg.MsgHdr.Zero) }
finalDomains := resolver.GetCachedDomains()
managementStillCached := false
for _, d := range finalDomains {
if d.SafeString() == "example.org" {
managementStillCached = true
break
}
}
assert.True(t, managementStillCached, "Management domain should never be removed")
} }
// MockResponseWriter is a simple mock implementation of dns.ResponseWriter for testing
type MockResponseWriter struct {
msg *dns.Msg
}
func (m *MockResponseWriter) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53}
}
func (m *MockResponseWriter) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
}
func (m *MockResponseWriter) WriteMsg(msg *dns.Msg) error {
m.msg = msg
return nil
}
func (m *MockResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (m *MockResponseWriter) Close() error {
return nil
}
func (m *MockResponseWriter) TsigStatus() error {
return nil
}
func (m *MockResponseWriter) TsigTimersOnly(bool) {}
func (m *MockResponseWriter) Hijack() {}

View File

@@ -74,7 +74,6 @@ type DefaultServer struct {
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int extraDomains map[domain.Domain]int
// management cache resolver for critical infrastructure domains
mgmtCacheResolver *mgmt.Resolver mgmtCacheResolver *mgmt.Resolver
// permanent related properties // permanent related properties
@@ -106,18 +105,15 @@ type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer // DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct { type DefaultServerConfig struct {
Ctx context.Context
WgInterface WGIface WgInterface WGIface
CustomAddress string CustomAddress string
StatusRecorder *peer.Status StatusRecorder *peer.Status
StateManager *statemanager.Manager StateManager *statemanager.Manager
DisableSys bool DisableSys bool
MgmtURL *url.URL
ServerDomains dnsconfig.ServerDomains
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) { func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if config.CustomAddress != "" { if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
@@ -134,20 +130,7 @@ func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
dnsService = newServiceViaListener(config.WgInterface, addrPort) dnsService = newServiceViaListener(config.WgInterface, addrPort)
} }
server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
log.Warnf("Failed to populate management cache from management URL: %v", err)
}
}
if server.mgmtCacheResolver != nil {
if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
}
}
return server, nil return server, nil
} }
@@ -197,7 +180,6 @@ func newDefaultServer(
handlerChain := NewHandlerChain() handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
// Create management cache resolver
mgmtCacheResolver := mgmt.NewResolver() mgmtCacheResolver := mgmt.NewResolver()
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
@@ -345,20 +327,8 @@ func (s *DefaultServer) Stop() {
maps.Clear(s.extraDomains) maps.Clear(s.extraDomains)
} }
// PopulateMgmtCacheFromConfig populates the management cache with domains from the client configuration
func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
if s.mgmtCacheResolver == nil {
return fmt.Errorf("management cache resolver not initialized")
}
log.Debug("populating management cache from client configuration")
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
@@ -465,12 +435,12 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
} }
if len(removedDomains) > 0 { if len(removedDomains) > 0 {
s.DeregisterHandler(removedDomains, PriorityMgmtCache) s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
} }
newDomains := s.mgmtCacheResolver.GetCachedDomains() newDomains := s.mgmtCacheResolver.GetCachedDomains()
if len(newDomains) > 0 { if len(newDomains) > 0 {
s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache) s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
} }
} }
@@ -935,3 +905,11 @@ func toZone(d domain.Domain) domain.Domain {
), ),
) )
} }
// PopulateManagementDomain populates the DNS cache with management domain
func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error {
if s.mgmtCacheResolver != nil && mgmtURL != nil {
return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL)
}
return nil
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
@@ -364,15 +363,12 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(DefaultServerConfig{ dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface, WgInterface: wgIface,
CustomAddress: "", CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"), StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil, StateManager: nil,
DisableSys: false, DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -483,15 +479,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(DefaultServerConfig{ dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
Ctx: context.Background(),
WgInterface: wgIface, WgInterface: wgIface,
CustomAddress: "", CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"), StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil, StateManager: nil,
DisableSys: false, DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
}) })
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
@@ -594,15 +587,12 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(DefaultServerConfig{ dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
Ctx: context.Background(),
WgInterface: &mocWGIface{}, WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort, CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"), StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil, StateManager: nil,
DisableSys: false, DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
}) })
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)

View File

@@ -33,7 +33,7 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -126,12 +126,6 @@ type EngineConfig struct {
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
// ManagementURL is the URL of the management server for DNS cache
ManagementURL *url.URL
// NetbirdConfig contains signal, relay, and flow server configuration
NetbirdConfig *mgmProto.NetbirdConfig
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -350,7 +344,7 @@ func (e *Engine) Stop() error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error { func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -395,13 +389,18 @@ func (e *Engine) Start() error {
return fmt.Errorf("read initial settings: %w", err) return fmt.Errorf("read initial settings: %w", err)
} }
dnsServer, err := e.newDnsServer(dnsConfig, e.config.ManagementURL, e.config.NetbirdConfig) dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil { if err != nil {
e.close() e.close()
return fmt.Errorf("create dns server: %w", err) return fmt.Errorf("create dns server: %w", err)
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
// Populate DNS cache with NetbirdConfig and management URL for early resolution
if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache: %v", err)
}
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: e.ctx, Context: e.ctx,
PublicKey: e.config.WgPrivateKey.PublicKey().String(), PublicKey: e.config.WgPrivateKey.PublicKey().String(),
@@ -666,6 +665,32 @@ func (e *Engine) removePeer(peerKey string) error {
return nil return nil
} }
// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response
func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
if e.dnsServer == nil {
return nil
}
// Populate management URL if provided
if mgmtURL != nil {
if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok {
if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
}
}
// Populate NetbirdConfig domains if provided
if netbirdConfig != nil {
serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err)
}
}
return nil
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -697,11 +722,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err) return fmt.Errorf("handle the flow configuration: %w", err)
} }
if e.dnsServer != nil { if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
serverDomains := config.ExtractFromNetbirdConfig(wCfg) log.Warnf("Failed to update DNS server config: %v", err)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
} }
// todo update signal // todo update signal
@@ -1587,7 +1609,7 @@ func (e *Engine) wgInterfaceCreate() (err error) {
return err return err
} }
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbirdConfig *mgmProto.NetbirdConfig) (dns.Server, error) { func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
// due to tests where we are using a mocked version of the DNS server // due to tests where we are using a mocked version of the DNS server
if e.dnsServer != nil { if e.dnsServer != nil {
return e.dnsServer, nil return e.dnsServer, nil
@@ -1612,18 +1634,13 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
return dnsServer, nil return dnsServer, nil
default: default:
// Extract domains from NetBird configuration
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig) dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
Ctx: e.ctx,
WgInterface: e.wgInterface, WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress, CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder, StatusRecorder: e.statusRecorder,
StateManager: e.stateManager, StateManager: e.stateManager,
DisableSys: e.config.DisableDNS, DisableSys: e.config.DisableDNS,
MgmtURL: mgmtURL,
ServerDomains: serverDomains,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1643,11 +1660,6 @@ func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall return e.firewall
} }
// GetDNSServer returns the DNS server
func (e *Engine) GetDNSServer() dns.Server {
return e.dnsServer
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {

View File

@@ -261,7 +261,7 @@ func TestEngine_SSH(t *testing.T) {
}, },
}, nil }, nil
} }
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -605,7 +605,7 @@ func TestEngine_Sync(t *testing.T) {
} }
}() }()
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1060,7 +1060,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
defer mu.Unlock() defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String()) guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid) device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err) t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done() wg.Done()

View File

@@ -39,7 +39,7 @@ func IsLoginRequired(ctx context.Context, config *Config) (bool, error) {
return false, err return false, err
} }
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) { if isLoginNeeded(err) {
return true, nil return true, nil
} }
@@ -68,14 +68,18 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
return err return err
} }
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) { if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required") log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err return err
} }
return err return nil
} }
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
@@ -100,11 +104,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
return mgmClient, err return mgmClient, err
} }
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, error) { func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey() serverKey, err := mgmClient.GetServerPublicKey()
if err != nil { if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err) log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err return nil, nil, err
} }
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
@@ -120,8 +124,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.BlockInbound, config.BlockInbound,
config.LazyConnectionEnabled, config.LazyConnectionEnabled,
) )
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, err return serverKey, loginResp, err
} }
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.