mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 11:00:06 +02:00
Cleanup
This commit is contained in:
@@ -272,11 +272,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
if err := c.engine.Start(); err != nil {
|
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
@@ -442,8 +443,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
BlockInbound: config.BlockInbound,
|
BlockInbound: config.BlockInbound,
|
||||||
|
|
||||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
ManagementURL: config.ManagementURL,
|
|
||||||
NetbirdConfig: netbirdConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -40,19 +41,34 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
|
|||||||
|
|
||||||
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
|
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
|
||||||
func extractValidDomain(rawURL string) (domain.Domain, error) {
|
func extractValidDomain(rawURL string) (domain.Domain, error) {
|
||||||
parsedURL, err := url.Parse(rawURL)
|
if rawURL == "" {
|
||||||
if err != nil {
|
return "", fmt.Errorf("empty URL")
|
||||||
// If URL parsing fails, it might be a raw host:port, try parsing as such
|
|
||||||
if host, _, err := net.SplitHostPort(rawURL); err == nil {
|
|
||||||
return extractDomainFromHost(host)
|
|
||||||
}
|
|
||||||
// If not host:port, try as raw hostname
|
|
||||||
return extractDomainFromHost(rawURL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
host := parsedURL.Hostname()
|
// Try standard URL parsing first (handles https://, http://, rels://, etc.)
|
||||||
if host == "" {
|
if parsedURL, err := url.Parse(rawURL); err == nil && parsedURL.Hostname() != "" {
|
||||||
return "", fmt.Errorf("no hostname in URL")
|
return extractDomainFromHost(parsedURL.Hostname())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract domain from various formats:
|
||||||
|
// - stun:domain:port -> domain
|
||||||
|
// - turns:domain:port?params -> domain
|
||||||
|
// - domain:port -> domain
|
||||||
|
host := rawURL
|
||||||
|
|
||||||
|
// Remove scheme prefix (stun:, turn:, turns:)
|
||||||
|
if colonIndex := strings.Index(host, ":"); colonIndex > 0 && colonIndex < 10 && !strings.Contains(host[:colonIndex], ".") {
|
||||||
|
host = host[colonIndex+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove port suffix
|
||||||
|
if hostOnly, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = hostOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove query parameters
|
||||||
|
if queryIndex := strings.Index(host, "?"); queryIndex > 0 {
|
||||||
|
host = host[:queryIndex]
|
||||||
}
|
}
|
||||||
|
|
||||||
return extractDomainFromHost(host)
|
return extractDomainFromHost(host)
|
||||||
|
148
client/internal/dns/config/domains_test.go
Normal file
148
client/internal/dns/config/domains_test.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractValidDomain(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
expected string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "HTTPS URL with port",
|
||||||
|
url: "https://api.netbird.io:443",
|
||||||
|
expected: "api.netbird.io",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP URL without port",
|
||||||
|
url: "http://signal.example.com",
|
||||||
|
expected: "signal.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Host with port (no scheme)",
|
||||||
|
url: "signal.netbird.io:443",
|
||||||
|
expected: "signal.netbird.io",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "STUN URL",
|
||||||
|
url: "stun:stun.netbird.io:443",
|
||||||
|
expected: "stun.netbird.io",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "STUN URL with different port",
|
||||||
|
url: "stun:stun.netbird.io:5555",
|
||||||
|
expected: "stun.netbird.io",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TURNS URL with query params",
|
||||||
|
url: "turns:turn.netbird.io:443?transport=tcp",
|
||||||
|
expected: "turn.netbird.io",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TURN URL",
|
||||||
|
url: "turn:turn.example.com:3478",
|
||||||
|
expected: "turn.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "REL URL",
|
||||||
|
url: "rel://relay.example.com:443",
|
||||||
|
expected: "relay.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RELS URL",
|
||||||
|
url: "rels://relay.netbird.io:443",
|
||||||
|
expected: "relay.netbird.io",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Raw hostname",
|
||||||
|
url: "example.org",
|
||||||
|
expected: "example.org",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IP address should be rejected",
|
||||||
|
url: "192.168.1.1",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IP address with port should be rejected",
|
||||||
|
url: "192.168.1.1:443",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address should be rejected",
|
||||||
|
url: "2001:db8::1",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty URL",
|
||||||
|
url: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := extractValidDomain(tt.url)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err, "Expected error for URL: %s", tt.url)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, "Unexpected error for URL: %s", tt.url)
|
||||||
|
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractDomainFromHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host string
|
||||||
|
expected string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid domain",
|
||||||
|
host: "example.com",
|
||||||
|
expected: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Subdomain",
|
||||||
|
host: "api.example.com",
|
||||||
|
expected: "api.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 address",
|
||||||
|
host: "192.168.1.1",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
host: "2001:db8::1",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty host",
|
||||||
|
host: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := extractDomainFromHost(tt.host)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err, "Expected error for host: %s", tt.host)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, "Unexpected error for host: %s", tt.host)
|
||||||
|
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@@ -16,25 +16,19 @@ import (
|
|||||||
|
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CacheEntry holds DNS records for a cached domain
|
|
||||||
type CacheEntry struct {
|
|
||||||
ARecords []dns.RR
|
|
||||||
AAAARecords []dns.RR
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains
|
// Resolver caches critical NetBird infrastructure domains
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
cache map[domain.Domain]CacheEntry
|
records map[dns.Question][]dns.RR
|
||||||
mutex sync.RWMutex
|
managementDomain *domain.Domain
|
||||||
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
cache: make(map[domain.Domain]CacheEntry),
|
records: make(map[dns.Question][]dns.RR),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,7 +45,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
question := r.Question[0]
|
question := r.Question[0]
|
||||||
qname := strings.ToLower(strings.TrimSuffix(question.Name, "."))
|
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||||
|
|
||||||
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
|
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
@@ -59,8 +53,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
domainKey := domain.Domain(qname)
|
records, found := m.records[question]
|
||||||
entry, found := m.cache[domainKey]
|
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
@@ -73,34 +66,19 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
resp.Authoritative = false
|
resp.Authoritative = false
|
||||||
resp.RecursionAvailable = true
|
resp.RecursionAvailable = true
|
||||||
|
|
||||||
var records []dns.RR
|
resp.Answer = append(resp.Answer, records...)
|
||||||
if question.Qtype == dns.TypeA {
|
|
||||||
records = entry.ARecords
|
|
||||||
} else if question.Qtype == dns.TypeAAAA {
|
|
||||||
records = entry.AAAARecords
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(records) == 0 {
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||||
m.continueToNext(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rr := range records {
|
|
||||||
rrCopy := dns.Copy(rr)
|
|
||||||
rrCopy.Header().Name = question.Name
|
|
||||||
resp.Answer = append(resp.Answer, rrCopy)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write response: %v", err)
|
log.Errorf("failed to write response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MatchSubdomains always returns true as required by the interface.
|
// MatchSubdomains returns false since this resolver only handles exact domain matches
|
||||||
|
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
|
||||||
func (m *Resolver) MatchSubdomains() bool {
|
func (m *Resolver) MatchSubdomains() bool {
|
||||||
return true
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// continueToNext signals the handler chain to continue to the next handler.
|
// continueToNext signals the handler chain to continue to the next handler.
|
||||||
@@ -115,8 +93,6 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// AddDomain manually adds a domain to cache by resolving it.
|
// AddDomain manually adds a domain to cache by resolving it.
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
log.Debugf("adding domain=%s to cache", d.SafeString())
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -130,7 +106,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
rr := &dns.A{
|
rr := &dns.A{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: d.PunycodeString() + ".",
|
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
|
||||||
Rrtype: dns.TypeA,
|
Rrtype: dns.TypeA,
|
||||||
Class: dns.ClassINET,
|
Class: dns.ClassINET,
|
||||||
Ttl: 300,
|
Ttl: 300,
|
||||||
@@ -141,7 +117,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
} else if ip.Is6() {
|
} else if ip.Is6() {
|
||||||
rr := &dns.AAAA{
|
rr := &dns.AAAA{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: d.PunycodeString() + ".",
|
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
|
||||||
Rrtype: dns.TypeAAAA,
|
Rrtype: dns.TypeAAAA,
|
||||||
Class: dns.ClassINET,
|
Class: dns.ClassINET,
|
||||||
Ttl: 300,
|
Ttl: 300,
|
||||||
@@ -153,10 +129,25 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
m.cache[d] = CacheEntry{
|
|
||||||
ARecords: aRecords,
|
if len(aRecords) > 0 {
|
||||||
AAAARecords: aaaaRecords,
|
aQuestion := dns.Question{
|
||||||
|
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aQuestion] = aRecords
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(aaaaRecords) > 0 {
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aaaaQuestion] = aaaaRecords
|
||||||
|
}
|
||||||
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||||
@@ -167,12 +158,21 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
||||||
if mgmtURL != nil {
|
if mgmtURL == nil {
|
||||||
if d, err := extractDomainFromURL(mgmtURL); err == nil {
|
return nil
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
}
|
||||||
log.Warnf("failed to add management domain: %v", err)
|
|
||||||
}
|
d, err := extractDomainFromURL(mgmtURL)
|
||||||
}
|
if err != nil {
|
||||||
|
return fmt.Errorf("extract domain from URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.managementDomain = &d
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
return fmt.Errorf("add domain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -183,191 +183,95 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
delete(m.cache, d)
|
aQuestion := dns.Question{
|
||||||
|
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
delete(m.records, aQuestion)
|
||||||
|
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
delete(m.records, aaaaQuestion)
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
|
|
||||||
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
|
|
||||||
if config == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m.addSignalDomain(ctx, config.Signal)
|
|
||||||
m.addRelayDomains(ctx, config.Relay)
|
|
||||||
m.addFlowDomain(ctx, config.Flow)
|
|
||||||
m.addStunDomains(ctx, config.Stuns)
|
|
||||||
m.addTurnDomains(ctx, config.Turns)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addSignalDomain adds signal server domain to cache.
|
|
||||||
func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostConfig) {
|
|
||||||
if signal == nil || signal.Uri == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
signalURL, err := url.Parse(signal.Uri)
|
|
||||||
if err != nil {
|
|
||||||
// If parsing fails, it might be a raw host:port, try adding a scheme
|
|
||||||
signalURL, err = url.Parse("https://" + signal.Uri)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse signal URL: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := extractDomainFromURL(signalURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to extract signal domain: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
|
||||||
log.Warnf("failed to add signal domain: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRelayDomains adds relay server domains to cache.
|
|
||||||
func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayConfig) {
|
|
||||||
if relay == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, relayAddr := range relay.Urls {
|
|
||||||
relayURL, err := url.Parse(relayAddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := extractDomainFromURL(relayURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
|
||||||
log.Warnf("failed to add relay domain: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addFlowDomain adds traffic flow server domain to cache.
|
|
||||||
func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig) {
|
|
||||||
if flow == nil || flow.Url == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
flowURL, err := url.Parse(flow.Url)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse flow URL: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := extractDomainFromURL(flowURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to extract flow domain: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
|
||||||
log.Warnf("failed to add flow domain: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCachedDomains returns a list of all cached domains.
|
// GetCachedDomains returns a list of all cached domains.
|
||||||
func (m *Resolver) GetCachedDomains() []domain.Domain {
|
func (m *Resolver) GetCachedDomains() domain.List {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
domains := make([]domain.Domain, 0, len(m.cache))
|
domainSet := make(map[domain.Domain]struct{})
|
||||||
for d := range m.cache {
|
for question := range m.records {
|
||||||
domains = append(domains, d)
|
domainName := strings.TrimSuffix(question.Name, ".")
|
||||||
}
|
domainSet[domain.Domain(domainName)] = struct{}{}
|
||||||
return domains
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearCache removes all cached domains and returns them for external deregistration.
|
|
||||||
func (m *Resolver) ClearCache() []domain.Domain {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
domains := make([]domain.Domain, 0, len(m.cache))
|
|
||||||
for d := range m.cache {
|
|
||||||
domains = append(domains, d)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.cache = make(map[domain.Domain]CacheEntry)
|
domains := make(domain.List, 0, len(domainSet))
|
||||||
log.Debugf("cleared %d cached domains", len(domains))
|
for d := range domainSet {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
|
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
|
||||||
// Returns domains that were removed for external deregistration.
|
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
|
||||||
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
|
|
||||||
log.Debugf("updating cache from NetbirdConfig")
|
|
||||||
|
|
||||||
currentDomains := m.GetCachedDomains()
|
currentDomains := m.GetCachedDomains()
|
||||||
newDomains := m.extractDomainsFromConfig(config)
|
newDomains := m.extractDomainsFromServerDomains(serverDomains)
|
||||||
|
|
||||||
var removedDomains []domain.Domain
|
removedDomains := m.removeStaleDomainsExceptManagement(currentDomains, newDomains)
|
||||||
for _, currentDomain := range currentDomains {
|
m.addNewDomains(ctx, newDomains)
|
||||||
found := false
|
|
||||||
for _, newDomain := range newDomains {
|
|
||||||
if currentDomain.SafeString() == newDomain.SafeString() {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
removedDomains = append(removedDomains, currentDomain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
for _, domainToRemove := range removedDomains {
|
|
||||||
delete(m.cache, domainToRemove)
|
|
||||||
log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
for _, newDomain := range newDomains {
|
|
||||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
|
||||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return removedDomains, nil
|
return removedDomains, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
|
// removeStaleDomainsExceptManagement removes domains not in newDomains, except management domain
|
||||||
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) {
|
func (m *Resolver) removeStaleDomainsExceptManagement(currentDomains, newDomains domain.List) domain.List {
|
||||||
log.Debugf("updating cache from ServerDomains")
|
var removedDomains domain.List
|
||||||
|
|
||||||
currentDomains := m.GetCachedDomains()
|
|
||||||
newDomains := m.extractDomainsFromServerDomains(serverDomains)
|
|
||||||
|
|
||||||
var removedDomains []domain.Domain
|
|
||||||
for _, currentDomain := range currentDomains {
|
for _, currentDomain := range currentDomains {
|
||||||
found := false
|
if m.isDomainInList(currentDomain, newDomains) {
|
||||||
for _, newDomain := range newDomains {
|
continue
|
||||||
if currentDomain.SafeString() == newDomain.SafeString() {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if !found {
|
|
||||||
removedDomains = append(removedDomains, currentDomain)
|
if m.isManagementDomain(currentDomain) {
|
||||||
if err := m.RemoveDomain(currentDomain); err != nil {
|
continue
|
||||||
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
|
}
|
||||||
}
|
|
||||||
|
removedDomains = append(removedDomains, currentDomain)
|
||||||
|
if err := m.RemoveDomain(currentDomain); err != nil {
|
||||||
|
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return removedDomains
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDomainInList checks if domain exists in the list
|
||||||
|
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
|
||||||
|
for _, d := range list {
|
||||||
|
if domain.SafeString() == d.SafeString() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isManagementDomain checks if domain is the protected management domain
|
||||||
|
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
return m.managementDomain != nil && domain.SafeString() == m.managementDomain.SafeString()
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNewDomains adds all new domains to the cache
|
||||||
|
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||||
for _, newDomain := range newDomains {
|
for _, newDomain := range newDomains {
|
||||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||||
@@ -375,12 +279,10 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
|||||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return removedDomains, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain {
|
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
|
||||||
var domains []domain.Domain
|
var domains domain.List
|
||||||
|
|
||||||
if serverDomains.Signal != "" {
|
if serverDomains.Signal != "" {
|
||||||
domains = append(domains, serverDomains.Signal)
|
domains = append(domains, serverDomains.Signal)
|
||||||
@@ -411,164 +313,6 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
|
|
||||||
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
|
|
||||||
if config == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var domains []domain.Domain
|
|
||||||
|
|
||||||
// Extract signal domain
|
|
||||||
domains = append(domains, m.extractSignalDomain(config)...)
|
|
||||||
|
|
||||||
// Extract relay domains
|
|
||||||
domains = append(domains, m.extractRelayDomains(config)...)
|
|
||||||
|
|
||||||
// Extract flow domain
|
|
||||||
domains = append(domains, m.extractFlowDomain(config)...)
|
|
||||||
|
|
||||||
// Extract STUN domains
|
|
||||||
domains = append(domains, m.extractSTUNDomains(config)...)
|
|
||||||
|
|
||||||
// Extract TURN domains
|
|
||||||
domains = append(domains, m.extractTURNDomains(config)...)
|
|
||||||
|
|
||||||
return domains
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
|
|
||||||
if config.Signal == nil || config.Signal.Uri == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
|
|
||||||
return []domain.Domain{d}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
|
||||||
if config.Relay == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var domains []domain.Domain
|
|
||||||
for _, relayURL := range config.Relay.Urls {
|
|
||||||
if d, err := m.extractDomainFromURL(relayURL); err == nil {
|
|
||||||
domains = append(domains, d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return domains
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
|
|
||||||
if config.Flow == nil || config.Flow.Url == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
|
|
||||||
return []domain.Domain{d}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
|
||||||
var domains []domain.Domain
|
|
||||||
for _, stun := range config.Stuns {
|
|
||||||
if stun != nil && stun.Uri != "" {
|
|
||||||
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
|
|
||||||
domains = append(domains, d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return domains
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
|
|
||||||
var domains []domain.Domain
|
|
||||||
for _, turn := range config.Turns {
|
|
||||||
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
|
|
||||||
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
|
|
||||||
domains = append(domains, d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return domains
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractDomainFromSignalConfig extracts domain from signal configuration.
|
|
||||||
func (m *Resolver) extractDomainFromSignalConfig(signal *mgmProto.HostConfig) (domain.Domain, error) {
|
|
||||||
signalURL, err := url.Parse(signal.Uri)
|
|
||||||
if err != nil {
|
|
||||||
// If parsing fails, it might be a raw host:port, try adding a scheme
|
|
||||||
signalURL, err = url.Parse("https://" + signal.Uri)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return extractDomainFromURL(signalURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractDomainFromURL extracts domain from a URL string.
|
|
||||||
func (m *Resolver) extractDomainFromURL(urlStr string) (domain.Domain, error) {
|
|
||||||
parsedURL, err := url.Parse(urlStr)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return extractDomainFromURL(parsedURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addStunDomains adds STUN server domains to cache.
|
|
||||||
func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostConfig) {
|
|
||||||
for _, stun := range stuns {
|
|
||||||
if stun == nil || stun.Uri == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
stunURL, err := url.Parse(stun.Uri)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := extractDomainFromURL(stunURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
|
||||||
log.Warnf("failed to add STUN domain: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addTurnDomains adds TURN server domains to cache.
|
|
||||||
func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.ProtectedHostConfig) {
|
|
||||||
for _, turn := range turns {
|
|
||||||
if turn == nil || turn.HostConfig == nil || turn.HostConfig.Uri == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
turnURL, err := url.Parse(turn.HostConfig.Uri)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := extractDomainFromURL(turnURL)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.AddDomain(ctx, d); err != nil {
|
|
||||||
log.Warnf("failed to add TURN domain: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractDomainFromURL extracts the domain from a URL.
|
// extractDomainFromURL extracts the domain from a URL.
|
||||||
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
||||||
if u == nil {
|
if u == nil {
|
||||||
|
@@ -2,22 +2,24 @@ package mgmt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResolver_NewResolver(t *testing.T) {
|
func TestResolver_NewResolver(t *testing.T) {
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
|
||||||
assert.NotNil(t, resolver)
|
assert.NotNil(t, resolver)
|
||||||
assert.NotNil(t, resolver.cache)
|
assert.NotNil(t, resolver.records)
|
||||||
assert.True(t, resolver.MatchSubdomains())
|
assert.False(t, resolver.MatchSubdomains())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||||
@@ -113,145 +115,173 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
|||||||
|
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
|
||||||
// Use IP address to avoid DNS resolution timeout
|
// Test with IP address - should return error since IP addresses are rejected
|
||||||
mgmtURL, _ := url.Parse("https://127.0.0.1")
|
mgmtURL, _ := url.Parse("https://127.0.0.1")
|
||||||
|
|
||||||
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
||||||
assert.NoError(t, err)
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "host is an IP address")
|
||||||
|
|
||||||
// IP addresses are rejected, so no domains should be cached
|
// No domains should be cached when using IP addresses
|
||||||
domains := resolver.GetCachedDomains()
|
domains := resolver.GetCachedDomains()
|
||||||
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
func TestResolver_ServeDNS(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
// Use IP addresses to avoid DNS resolution timeouts
|
// Add a test domain to the cache - use example.org which is reserved for testing
|
||||||
netbirdConfig := &mgmProto.NetbirdConfig{
|
testDomain, err := domain.FromString("example.org")
|
||||||
Signal: &mgmProto.HostConfig{
|
if err != nil {
|
||||||
Uri: "https://10.0.0.1",
|
t.Fatalf("Failed to create domain: %v", err)
|
||||||
},
|
}
|
||||||
Relay: &mgmProto.RelayConfig{
|
err = resolver.AddDomain(ctx, testDomain)
|
||||||
Urls: []string{
|
if err != nil {
|
||||||
"https://10.0.0.2:443",
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||||
"https://10.0.0.3:443",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Flow: &mgmProto.FlowConfig{
|
|
||||||
Url: "https://10.0.0.4:80",
|
|
||||||
},
|
|
||||||
Stuns: []*mgmProto.HostConfig{
|
|
||||||
{Uri: "stun:10.0.0.5:3478"},
|
|
||||||
{Uri: "stun:10.0.0.6:3478"},
|
|
||||||
},
|
|
||||||
Turns: []*mgmProto.ProtectedHostConfig{
|
|
||||||
{
|
|
||||||
HostConfig: &mgmProto.HostConfig{
|
|
||||||
Uri: "turn:10.0.0.7:3478",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
HostConfig: &mgmProto.HostConfig{
|
|
||||||
Uri: "turn:10.0.0.8:3478",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
|
// Test A record query for cached domain
|
||||||
assert.NoError(t, err)
|
t.Run("Cached domain A record", func(t *testing.T) {
|
||||||
|
var capturedMsg *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
capturedMsg = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// IP addresses are rejected, so no domains should be cached
|
req := new(dns.Msg)
|
||||||
domains := resolver.GetCachedDomains()
|
req.SetQuestion("example.org.", dns.TypeA)
|
||||||
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
|
||||||
|
resolver.ServeDNS(mockWriter, req)
|
||||||
|
|
||||||
|
assert.NotNil(t, capturedMsg)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
|
||||||
|
assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test uncached domain signals to continue to next handler
|
||||||
|
t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
|
||||||
|
var capturedMsg *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
capturedMsg = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := new(dns.Msg)
|
||||||
|
req.SetQuestion("unknown.example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resolver.ServeDNS(mockWriter, req)
|
||||||
|
|
||||||
|
assert.NotNil(t, capturedMsg)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
|
||||||
|
// Zero flag set to true signals the handler chain to continue to next handler
|
||||||
|
assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
|
||||||
|
assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test that subdomains of cached domains are NOT resolved
|
||||||
|
t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
|
||||||
|
var capturedMsg *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
capturedMsg = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query for a subdomain of our cached domain
|
||||||
|
req := new(dns.Msg)
|
||||||
|
req.SetQuestion("sub.example.org.", dns.TypeA)
|
||||||
|
|
||||||
|
resolver.ServeDNS(mockWriter, req)
|
||||||
|
|
||||||
|
assert.NotNil(t, capturedMsg)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
|
||||||
|
assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
|
||||||
|
assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case-insensitive matching
|
||||||
|
t.Run("Case-insensitive domain matching", func(t *testing.T) {
|
||||||
|
var capturedMsg *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
capturedMsg = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query with different casing
|
||||||
|
req := new(dns.Msg)
|
||||||
|
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
|
||||||
|
|
||||||
|
resolver.ServeDNS(mockWriter, req)
|
||||||
|
|
||||||
|
assert.NotNil(t, capturedMsg)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
|
||||||
|
assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) {
|
func TestResolver_GetCachedDomains(t *testing.T) {
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
// Test with empty initial config and then add domains
|
testDomain, err := domain.FromString("example.org")
|
||||||
initialConfig := &mgmProto.NetbirdConfig{}
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create domain: %v", err)
|
||||||
// Start with empty config
|
}
|
||||||
removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig)
|
err = resolver.AddDomain(ctx, testDomain)
|
||||||
assert.NoError(t, err)
|
if err != nil {
|
||||||
assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache")
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||||
|
|
||||||
// Update to config with IP addresses instead of domains to avoid DNS resolution
|
|
||||||
// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
|
|
||||||
updatedConfig := &mgmProto.NetbirdConfig{
|
|
||||||
Signal: &mgmProto.HostConfig{
|
|
||||||
Uri: "https://127.0.0.1",
|
|
||||||
},
|
|
||||||
Flow: &mgmProto.FlowConfig{
|
|
||||||
Url: "https://192.168.1.1:80",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig)
|
cachedDomains := resolver.GetCachedDomains()
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify the method completes successfully without DNS timeouts
|
assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
|
||||||
assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update")
|
assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
|
||||||
|
assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
|
||||||
// Verify no domains were actually added since IPs are rejected
|
|
||||||
domains := resolver.GetCachedDomains()
|
|
||||||
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolver_ContinueToNext(t *testing.T) {
|
func TestResolver_ManagementDomainProtection(t *testing.T) {
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
// Create a mock response writer to capture the response
|
mgmtURL, _ := url.Parse("https://example.org")
|
||||||
mockWriter := &MockResponseWriter{}
|
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a test DNS query
|
initialDomains := resolver.GetCachedDomains()
|
||||||
req := new(dns.Msg)
|
if len(initialDomains) == 0 {
|
||||||
req.SetQuestion("unknown.example.com.", dns.TypeA)
|
t.Skip("Management domain failed to resolve, skipping test")
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
|
||||||
|
assert.Equal(t, "example.org", initialDomains[0].SafeString())
|
||||||
|
|
||||||
// Call continueToNext
|
serverDomains := dnsconfig.ServerDomains{
|
||||||
resolver.continueToNext(mockWriter, req)
|
Signal: "google.com",
|
||||||
|
Relay: []domain.Domain{"cloudflare.com"},
|
||||||
|
}
|
||||||
|
|
||||||
// Verify the response
|
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
|
||||||
assert.NotNil(t, mockWriter.msg)
|
if err != nil {
|
||||||
assert.Equal(t, dns.RcodeNameError, mockWriter.msg.Rcode)
|
t.Logf("Server domains update failed: %v", err)
|
||||||
assert.True(t, mockWriter.msg.MsgHdr.Zero)
|
}
|
||||||
|
|
||||||
|
finalDomains := resolver.GetCachedDomains()
|
||||||
|
|
||||||
|
managementStillCached := false
|
||||||
|
for _, d := range finalDomains {
|
||||||
|
if d.SafeString() == "example.org" {
|
||||||
|
managementStillCached = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, managementStillCached, "Management domain should never be removed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockResponseWriter is a simple mock implementation of dns.ResponseWriter for testing
|
|
||||||
type MockResponseWriter struct {
|
|
||||||
msg *dns.Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) LocalAddr() net.Addr {
|
|
||||||
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) RemoteAddr() net.Addr {
|
|
||||||
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) WriteMsg(msg *dns.Msg) error {
|
|
||||||
m.msg = msg
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) Write([]byte) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) TsigStatus() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) TsigTimersOnly(bool) {}
|
|
||||||
|
|
||||||
func (m *MockResponseWriter) Hijack() {}
|
|
||||||
|
@@ -74,7 +74,6 @@ type DefaultServer struct {
|
|||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
extraDomains map[domain.Domain]int
|
extraDomains map[domain.Domain]int
|
||||||
|
|
||||||
// management cache resolver for critical infrastructure domains
|
|
||||||
mgmtCacheResolver *mgmt.Resolver
|
mgmtCacheResolver *mgmt.Resolver
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
@@ -106,18 +105,15 @@ type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
|||||||
|
|
||||||
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
// DefaultServerConfig holds configuration parameters for NewDefaultServer
|
||||||
type DefaultServerConfig struct {
|
type DefaultServerConfig struct {
|
||||||
Ctx context.Context
|
|
||||||
WgInterface WGIface
|
WgInterface WGIface
|
||||||
CustomAddress string
|
CustomAddress string
|
||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
StateManager *statemanager.Manager
|
StateManager *statemanager.Manager
|
||||||
DisableSys bool
|
DisableSys bool
|
||||||
MgmtURL *url.URL
|
|
||||||
ServerDomains dnsconfig.ServerDomains
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
|
func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) {
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if config.CustomAddress != "" {
|
if config.CustomAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
|
||||||
@@ -134,20 +130,7 @@ func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
|
|||||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||||
|
|
||||||
if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
|
|
||||||
if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
|
|
||||||
log.Warnf("Failed to populate management cache from management URL: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if server.mgmtCacheResolver != nil {
|
|
||||||
if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
|
|
||||||
log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -197,7 +180,6 @@ func newDefaultServer(
|
|||||||
handlerChain := NewHandlerChain()
|
handlerChain := NewHandlerChain()
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
// Create management cache resolver
|
|
||||||
mgmtCacheResolver := mgmt.NewResolver()
|
mgmtCacheResolver := mgmt.NewResolver()
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
@@ -345,20 +327,8 @@ func (s *DefaultServer) Stop() {
|
|||||||
maps.Clear(s.extraDomains)
|
maps.Clear(s.extraDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateMgmtCacheFromConfig populates the management cache with domains from the client configuration
|
|
||||||
func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
|
|
||||||
if s.mgmtCacheResolver == nil {
|
|
||||||
return fmt.Errorf("management cache resolver not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("populating management cache from client configuration")
|
|
||||||
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
|
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||||
s.hostsDNSHolder.set(hostsDnsList)
|
s.hostsDNSHolder.set(hostsDnsList)
|
||||||
|
|
||||||
@@ -465,12 +435,12 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(removedDomains) > 0 {
|
if len(removedDomains) > 0 {
|
||||||
s.DeregisterHandler(removedDomains, PriorityMgmtCache)
|
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
newDomains := s.mgmtCacheResolver.GetCachedDomains()
|
newDomains := s.mgmtCacheResolver.GetCachedDomains()
|
||||||
if len(newDomains) > 0 {
|
if len(newDomains) > 0 {
|
||||||
s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache)
|
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -935,3 +905,11 @@ func toZone(d domain.Domain) domain.Domain {
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateManagementDomain populates the DNS cache with management domain
|
||||||
|
func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error {
|
||||||
|
if s.mgmtCacheResolver != nil && mgmtURL != nil {
|
||||||
|
return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -23,7 +23,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
@@ -364,15 +363,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
Ctx: context.Background(),
|
|
||||||
WgInterface: wgIface,
|
WgInterface: wgIface,
|
||||||
CustomAddress: "",
|
CustomAddress: "",
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
StateManager: nil,
|
StateManager: nil,
|
||||||
DisableSys: false,
|
DisableSys: false,
|
||||||
MgmtURL: nil,
|
|
||||||
ServerDomains: dnsconfig.ServerDomains{},
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -483,15 +479,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
Ctx: context.Background(),
|
|
||||||
WgInterface: wgIface,
|
WgInterface: wgIface,
|
||||||
CustomAddress: "",
|
CustomAddress: "",
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
StateManager: nil,
|
StateManager: nil,
|
||||||
DisableSys: false,
|
DisableSys: false,
|
||||||
MgmtURL: nil,
|
|
||||||
ServerDomains: dnsconfig.ServerDomains{},
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
@@ -594,15 +587,12 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(DefaultServerConfig{
|
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
|
||||||
Ctx: context.Background(),
|
|
||||||
WgInterface: &mocWGIface{},
|
WgInterface: &mocWGIface{},
|
||||||
CustomAddress: testCase.addrPort,
|
CustomAddress: testCase.addrPort,
|
||||||
StatusRecorder: peer.NewRecorder("mgm"),
|
StatusRecorder: peer.NewRecorder("mgm"),
|
||||||
StateManager: nil,
|
StateManager: nil,
|
||||||
DisableSys: false,
|
DisableSys: false,
|
||||||
MgmtURL: nil,
|
|
||||||
ServerDomains: dnsconfig.ServerDomains{},
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
|
@@ -33,7 +33,7 @@ import (
|
|||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -126,12 +126,6 @@ type EngineConfig struct {
|
|||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
|
|
||||||
LazyConnectionEnabled bool
|
LazyConnectionEnabled bool
|
||||||
|
|
||||||
// ManagementURL is the URL of the management server for DNS cache
|
|
||||||
ManagementURL *url.URL
|
|
||||||
|
|
||||||
// NetbirdConfig contains signal, relay, and flow server configuration
|
|
||||||
NetbirdConfig *mgmProto.NetbirdConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -350,7 +344,7 @@ func (e *Engine) Stop() error {
|
|||||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||||
// Connections to remote peers are not established here.
|
// Connections to remote peers are not established here.
|
||||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||||
func (e *Engine) Start() error {
|
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -395,13 +389,18 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("read initial settings: %w", err)
|
return fmt.Errorf("read initial settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := e.newDnsServer(dnsConfig, e.config.ManagementURL, e.config.NetbirdConfig)
|
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("create dns server: %w", err)
|
return fmt.Errorf("create dns server: %w", err)
|
||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
|
// Populate DNS cache with NetbirdConfig and management URL for early resolution
|
||||||
|
if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil {
|
||||||
|
log.Warnf("failed to populate DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||||
Context: e.ctx,
|
Context: e.ctx,
|
||||||
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
|
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||||
@@ -666,6 +665,32 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response
|
||||||
|
func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
|
||||||
|
if e.dnsServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate management URL if provided
|
||||||
|
if mgmtURL != nil {
|
||||||
|
if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok {
|
||||||
|
if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil {
|
||||||
|
log.Warnf("failed to populate DNS cache with management URL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate NetbirdConfig domains if provided
|
||||||
|
if netbirdConfig != nil {
|
||||||
|
serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig)
|
||||||
|
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
|
||||||
|
return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
@@ -697,11 +722,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.dnsServer != nil {
|
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||||
serverDomains := config.ExtractFromNetbirdConfig(wCfg)
|
log.Warnf("Failed to update DNS server config: %v", err)
|
||||||
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
|
|
||||||
log.Warnf("Failed to update DNS server config: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo update signal
|
// todo update signal
|
||||||
@@ -1587,7 +1609,7 @@ func (e *Engine) wgInterfaceCreate() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbirdConfig *mgmProto.NetbirdConfig) (dns.Server, error) {
|
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||||
// due to tests where we are using a mocked version of the DNS server
|
// due to tests where we are using a mocked version of the DNS server
|
||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
return e.dnsServer, nil
|
return e.dnsServer, nil
|
||||||
@@ -1612,18 +1634,13 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
|
|||||||
return dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// Extract domains from NetBird configuration
|
|
||||||
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
|
||||||
|
|
||||||
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
|
|
||||||
Ctx: e.ctx,
|
|
||||||
WgInterface: e.wgInterface,
|
WgInterface: e.wgInterface,
|
||||||
CustomAddress: e.config.CustomDNSAddress,
|
CustomAddress: e.config.CustomDNSAddress,
|
||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
StateManager: e.stateManager,
|
StateManager: e.stateManager,
|
||||||
DisableSys: e.config.DisableDNS,
|
DisableSys: e.config.DisableDNS,
|
||||||
MgmtURL: mgmtURL,
|
|
||||||
ServerDomains: serverDomains,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1643,11 +1660,6 @@ func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
|||||||
return e.firewall
|
return e.firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDNSServer returns the DNS server
|
|
||||||
func (e *Engine) GetDNSServer() dns.Server {
|
|
||||||
return e.dnsServer
|
|
||||||
}
|
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -261,7 +261,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
err = engine.Start()
|
err = engine.Start(nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -605,7 +605,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = engine.Start()
|
err = engine.Start(nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@@ -1060,7 +1060,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
|||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
guid := fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
err = engine.Start()
|
err = engine.Start(nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
t.Errorf("unable to start engine for peer %d with error %v", j, err)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
|
@@ -39,7 +39,7 @@ func IsLoginRequired(ctx context.Context, config *Config) (bool, error) {
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||||
if isLoginNeeded(err) {
|
if isLoginNeeded(err) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -68,14 +68,18 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
log.Debugf("peer registration required")
|
log.Debugf("peer registration required")
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
||||||
@@ -100,11 +104,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
|
|||||||
return mgmClient, err
|
return mgmClient, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, error) {
|
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
serverKey, err := mgmClient.GetServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
@@ -120,8 +124,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
|||||||
config.BlockInbound,
|
config.BlockInbound,
|
||||||
config.LazyConnectionEnabled,
|
config.LazyConnectionEnabled,
|
||||||
)
|
)
|
||||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
return serverKey, err
|
return serverKey, loginResp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
|
Reference in New Issue
Block a user