mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 11:00:06 +02:00
Cache management domains
This commit is contained in:
@@ -259,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, loginResp.GetNetbirdConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
@@ -413,7 +413,7 @@ func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, netbirdConfig *mgmProto.NetbirdConfig) (*EngineConfig, error) {
|
||||||
nm := false
|
nm := false
|
||||||
if config.NetworkMonitor != nil {
|
if config.NetworkMonitor != nil {
|
||||||
nm = *config.NetworkMonitor
|
nm = *config.NetworkMonitor
|
||||||
@@ -442,6 +442,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
BlockInbound: config.BlockInbound,
|
BlockInbound: config.BlockInbound,
|
||||||
|
|
||||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
|
ManagementURL: config.ManagementURL,
|
||||||
|
NetbirdConfig: netbirdConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
|
@@ -11,6 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
PriorityMgmtCache = 150
|
||||||
PriorityLocal = 100
|
PriorityLocal = 100
|
||||||
PriorityDNSRoute = 75
|
PriorityDNSRoute = 75
|
||||||
PriorityUpstream = 50
|
PriorityUpstream = 50
|
||||||
|
@@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool {
|
|||||||
|
|
||||||
// String returns a string representation of the local resolver
|
// String returns a string representation of the local resolver
|
||||||
func (d *Resolver) String() string {
|
func (d *Resolver) String() string {
|
||||||
return fmt.Sprintf("local resolver [%d records]", len(d.records))
|
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Resolver) Stop() {}
|
func (d *Resolver) Stop() {}
|
||||||
|
504
client/internal/dns/mgmt/mgmt.go
Normal file
504
client/internal/dns/mgmt/mgmt.go
Normal file
@@ -0,0 +1,504 @@
|
|||||||
|
package mgmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CacheEntry holds DNS records for a cached domain
|
||||||
|
type CacheEntry struct {
|
||||||
|
ARecords []dns.RR
|
||||||
|
AAAARecords []dns.RR
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolver caches critical NetBird infrastructure domains
|
||||||
|
type Resolver struct {
|
||||||
|
cache map[domain.Domain]CacheEntry
|
||||||
|
mutex sync.RWMutex
|
||||||
|
systemResolver *net.Resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResolver creates a new management domains cache resolver.
|
||||||
|
func NewResolver() *Resolver {
|
||||||
|
return &Resolver{
|
||||||
|
cache: make(map[domain.Domain]CacheEntry),
|
||||||
|
systemResolver: net.DefaultResolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the resolver.
|
||||||
|
func (m *Resolver) String() string {
|
||||||
|
return "MgmtCacheResolver"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeDNS implements dns.Handler interface.
|
||||||
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
question := r.Question[0]
|
||||||
|
qname := strings.ToLower(strings.TrimSuffix(question.Name, "."))
|
||||||
|
|
||||||
|
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("MgmtCache: checking cache for domain=%s type=%s", qname, dns.TypeToString[question.Qtype])
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
parsedDomain, err := domain.FromString(qname)
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("MgmtCache: invalid domain format: %s", qname)
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, found := m.cache[parsedDomain]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
log.Tracef("MgmtCache: no cache entry found for domain=%s", qname)
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Authoritative = false
|
||||||
|
resp.RecursionAvailable = true
|
||||||
|
|
||||||
|
var records []dns.RR
|
||||||
|
if question.Qtype == dns.TypeA {
|
||||||
|
records = entry.ARecords
|
||||||
|
} else if question.Qtype == dns.TypeAAAA {
|
||||||
|
records = entry.AAAARecords
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(records) == 0 {
|
||||||
|
log.Tracef("MgmtCache: no %s records for domain=%s", dns.TypeToString[question.Qtype], parsedDomain.SafeString())
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rr := range records {
|
||||||
|
rrCopy := dns.Copy(rr)
|
||||||
|
rrCopy.Header().Name = question.Name
|
||||||
|
resp.Answer = append(resp.Answer, rrCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("MgmtCache: serving %d cached records for domain=%s", len(resp.Answer), parsedDomain.SafeString())
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("MgmtCache: failed to write response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchSubdomains always returns true as required by the interface.
|
||||||
|
func (m *Resolver) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// continueToNext signals the handler chain to continue to the next handler.
|
||||||
|
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
resp.MsgHdr.Zero = true
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("MgmtCache: failed to write continue signal: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDomain manually adds a domain to cache by resolving it.
|
||||||
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
|
log.Debugf("MgmtCache: adding domain=%s to cache", d.SafeString())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var aRecords, aaaaRecords []dns.RR
|
||||||
|
|
||||||
|
if ips, err := m.systemResolver.LookupNetIP(ctx, "ip", d.PunycodeString()); err == nil {
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ip.Is4() {
|
||||||
|
rr := &dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: d.PunycodeString() + ".",
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
A: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aRecords = append(aRecords, rr)
|
||||||
|
} else if ip.Is6() {
|
||||||
|
rr := &dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: d.PunycodeString() + ".",
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
AAAA: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aaaaRecords = append(aaaaRecords, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.cache[d] = CacheEntry{
|
||||||
|
ARecords: aRecords,
|
||||||
|
AAAARecords: aaaaRecords,
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("MgmtCache: added domain=%s with %d A records and %d AAAA records",
|
||||||
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
} else {
|
||||||
|
log.Warnf("MgmtCache: failed to resolve domain=%s: %v", d.SafeString(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
|
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
||||||
|
if mgmtURL != nil {
|
||||||
|
if d, err := extractDomainFromURL(mgmtURL); err == nil {
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to add management domain: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
|
||||||
|
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.addSignalDomain(ctx, config.Signal)
|
||||||
|
m.addRelayDomains(ctx, config.Relay)
|
||||||
|
m.addFlowDomain(ctx, config.Flow)
|
||||||
|
m.addStunDomains(ctx, config.Stuns)
|
||||||
|
m.addTurnDomains(ctx, config.Turns)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addSignalDomain adds signal server domain to cache.
|
||||||
|
func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostConfig) {
|
||||||
|
if signal == nil || signal.Uri == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
signalURL, err := url.Parse(signal.Uri)
|
||||||
|
if err != nil {
|
||||||
|
// If parsing fails, it might be a raw host:port, try adding a scheme
|
||||||
|
signalURL, err = url.Parse("https://" + signal.Uri)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to parse signal URL: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := extractDomainFromURL(signalURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to extract signal domain: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to add signal domain: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRelayDomains adds relay server domains to cache.
|
||||||
|
func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayConfig) {
|
||||||
|
if relay == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, relayAddr := range relay.Urls {
|
||||||
|
relayURL, err := url.Parse(relayAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to parse relay URL %s: %v", relayAddr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := extractDomainFromURL(relayURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to extract relay domain from %s: %v", relayAddr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to add relay domain: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addFlowDomain adds traffic flow server domain to cache.
|
||||||
|
func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig) {
|
||||||
|
if flow == nil || flow.Url == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flowURL, err := url.Parse(flow.Url)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to parse flow URL: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := extractDomainFromURL(flowURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to extract flow domain: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to add flow domain: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCachedDomains returns a list of all cached domains.
|
||||||
|
func (m *Resolver) GetCachedDomains() []domain.Domain {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
domains := make([]domain.Domain, 0, len(m.cache))
|
||||||
|
for d := range m.cache {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearCache removes all cached domains and returns them for external deregistration.
|
||||||
|
func (m *Resolver) ClearCache() []domain.Domain {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
domains := make([]domain.Domain, 0, len(m.cache))
|
||||||
|
for d := range m.cache {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.cache = make(map[domain.Domain]CacheEntry)
|
||||||
|
log.Debugf("MgmtCache: cleared %d cached domains", len(domains))
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
|
||||||
|
// Returns domains that were removed for external deregistration.
|
||||||
|
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
|
||||||
|
log.Debugf("MgmtCache: updating cache from NetbirdConfig")
|
||||||
|
|
||||||
|
currentDomains := m.GetCachedDomains()
|
||||||
|
newDomains := m.extractDomainsFromConfig(config)
|
||||||
|
|
||||||
|
var removedDomains []domain.Domain
|
||||||
|
for _, currentDomain := range currentDomains {
|
||||||
|
found := false
|
||||||
|
for _, newDomain := range newDomains {
|
||||||
|
if currentDomain.SafeString() == newDomain.SafeString() {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
removedDomains = append(removedDomains, currentDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
for _, domainToRemove := range removedDomains {
|
||||||
|
delete(m.cache, domainToRemove)
|
||||||
|
log.Debugf("MgmtCache: removed domain=%s from cache", domainToRemove.SafeString())
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
for _, newDomain := range newDomains {
|
||||||
|
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return removedDomains, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
|
||||||
|
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var domains []domain.Domain
|
||||||
|
|
||||||
|
if config.Signal != nil && config.Signal.Uri != "" {
|
||||||
|
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Relay != nil {
|
||||||
|
for _, relayURL := range config.Relay.Urls {
|
||||||
|
if d, err := m.extractDomainFromURL(relayURL); err == nil {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Flow != nil && config.Flow.Url != "" {
|
||||||
|
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stun := range config.Stuns {
|
||||||
|
if stun != nil && stun.Uri != "" {
|
||||||
|
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, turn := range config.Turns {
|
||||||
|
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
|
||||||
|
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
|
||||||
|
domains = append(domains, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDomainFromSignalConfig extracts domain from signal configuration.
|
||||||
|
func (m *Resolver) extractDomainFromSignalConfig(signal *mgmProto.HostConfig) (domain.Domain, error) {
|
||||||
|
signalURL, err := url.Parse(signal.Uri)
|
||||||
|
if err != nil {
|
||||||
|
// If parsing fails, it might be a raw host:port, try adding a scheme
|
||||||
|
signalURL, err = url.Parse("https://" + signal.Uri)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return extractDomainFromURL(signalURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDomainFromURL extracts domain from a URL string.
|
||||||
|
func (m *Resolver) extractDomainFromURL(urlStr string) (domain.Domain, error) {
|
||||||
|
parsedURL, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return extractDomainFromURL(parsedURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addStunDomains adds STUN server domains to cache.
|
||||||
|
func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostConfig) {
|
||||||
|
for _, stun := range stuns {
|
||||||
|
if stun == nil || stun.Uri == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
stunURL, err := url.Parse(stun.Uri)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to parse STUN URL %s: %v", stun.Uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := extractDomainFromURL(stunURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to extract STUN domain from %s: %v", stun.Uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to add STUN domain: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addTurnDomains adds TURN server domains to cache.
|
||||||
|
func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.ProtectedHostConfig) {
|
||||||
|
for _, turn := range turns {
|
||||||
|
if turn == nil || turn.HostConfig == nil || turn.HostConfig.Uri == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
turnURL, err := url.Parse(turn.HostConfig.Uri)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := extractDomainFromURL(turnURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.AddDomain(ctx, d); err != nil {
|
||||||
|
log.Warnf("MgmtCache: failed to add TURN domain: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDomainFromURL extracts the domain from a URL.
|
||||||
|
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
||||||
|
if u == nil {
|
||||||
|
return "", errors.New("invalid URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
host := u.Host
|
||||||
|
// If Host is empty, try to extract from Opaque (for schemes like stun:domain:port)
|
||||||
|
if host == "" && u.Opaque != "" {
|
||||||
|
host = u.Opaque
|
||||||
|
}
|
||||||
|
if host == "" && u.Path != "" {
|
||||||
|
host = strings.TrimPrefix(u.Path, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
return "", errors.New("empty host")
|
||||||
|
}
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case u.Host != "":
|
||||||
|
host = u.Host
|
||||||
|
case u.Opaque != "":
|
||||||
|
host = u.Opaque
|
||||||
|
default:
|
||||||
|
host = strings.TrimPrefix(u.Path, "/")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := netip.ParseAddr(host); err == nil {
|
||||||
|
return "", errors.New("host is an IP address, skipping")
|
||||||
|
}
|
||||||
|
|
||||||
|
return domain.FromString(host)
|
||||||
|
}
|
227
client/internal/dns/mgmt/mgmt_test.go
Normal file
227
client/internal/dns/mgmt/mgmt_test.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package mgmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolver_NewResolver(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
assert.NotNil(t, resolver)
|
||||||
|
assert.NotNil(t, resolver.cache)
|
||||||
|
assert.True(t, resolver.MatchSubdomains())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
urlStr string
|
||||||
|
expectedDom string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "HTTPS URL with port",
|
||||||
|
urlStr: "https://api.netbird.io:443",
|
||||||
|
expectedDom: "api.netbird.io",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP URL without port",
|
||||||
|
urlStr: "http://signal.example.com",
|
||||||
|
expectedDom: "signal.example.com",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with path",
|
||||||
|
urlStr: "https://relay.netbird.io/status",
|
||||||
|
expectedDom: "relay.netbird.io",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid URL",
|
||||||
|
urlStr: "not-a-valid-url",
|
||||||
|
expectedDom: "not-a-valid-url",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty URL",
|
||||||
|
urlStr: "",
|
||||||
|
expectedDom: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "STUN URL",
|
||||||
|
urlStr: "stun:stun.example.com:3478",
|
||||||
|
expectedDom: "stun.example.com",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TURN URL",
|
||||||
|
urlStr: "turn:turn.example.com:3478",
|
||||||
|
expectedDom: "turn.example.com",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "REL URL",
|
||||||
|
urlStr: "rel://relay.example.com:443",
|
||||||
|
expectedDom: "relay.example.com",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RELS URL",
|
||||||
|
urlStr: "rels://relay.example.com:443",
|
||||||
|
expectedDom: "relay.example.com",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var parsedURL *url.URL
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if tt.urlStr != "" {
|
||||||
|
parsedURL, err = url.Parse(tt.urlStr)
|
||||||
|
if err != nil && !tt.expectError {
|
||||||
|
t.Fatalf("Failed to parse URL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, err := extractDomainFromURL(parsedURL)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedDom, domain.SafeString())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_PopulateFromConfig(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
mgmtURL, _ := url.Parse("https://api.netbird.io")
|
||||||
|
|
||||||
|
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Give some time for async population
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
domains := resolver.GetCachedDomains()
|
||||||
|
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
netbirdConfig := &mgmProto.NetbirdConfig{
|
||||||
|
Signal: &mgmProto.HostConfig{
|
||||||
|
Uri: "https://signal.netbird.io",
|
||||||
|
},
|
||||||
|
Relay: &mgmProto.RelayConfig{
|
||||||
|
Urls: []string{
|
||||||
|
"https://relay1.netbird.io:443",
|
||||||
|
"https://relay2.netbird.io:443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Flow: &mgmProto.FlowConfig{
|
||||||
|
Url: "https://flow.netbird.io:80",
|
||||||
|
},
|
||||||
|
Stuns: []*mgmProto.HostConfig{
|
||||||
|
{Uri: "stun:stun1.netbird.io:3478"},
|
||||||
|
{Uri: "stun:stun2.netbird.io:3478"},
|
||||||
|
},
|
||||||
|
Turns: []*mgmProto.ProtectedHostConfig{
|
||||||
|
{
|
||||||
|
HostConfig: &mgmProto.HostConfig{
|
||||||
|
Uri: "turn:turn1.netbird.io:3478",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
HostConfig: &mgmProto.HostConfig{
|
||||||
|
Uri: "turn:turn2.netbird.io:3478",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Give some time for async population
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
domains := resolver.GetCachedDomains()
|
||||||
|
assert.GreaterOrEqual(t, len(domains), 0) // Domains might not be cached yet due to async nature
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ContinueToNext(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
// Create a mock response writer to capture the response
|
||||||
|
mockWriter := &MockResponseWriter{}
|
||||||
|
|
||||||
|
// Create a test DNS query
|
||||||
|
req := new(dns.Msg)
|
||||||
|
req.SetQuestion("unknown.example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
// Call continueToNext
|
||||||
|
resolver.continueToNext(mockWriter, req)
|
||||||
|
|
||||||
|
// Verify the response
|
||||||
|
assert.NotNil(t, mockWriter.msg)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, mockWriter.msg.Rcode)
|
||||||
|
assert.True(t, mockWriter.msg.MsgHdr.Zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockResponseWriter is a simple mock implementation of dns.ResponseWriter for testing
|
||||||
|
type MockResponseWriter struct {
|
||||||
|
msg *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) LocalAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) RemoteAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) WriteMsg(msg *dns.Msg) error {
|
||||||
|
m.msg = msg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) Write([]byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) TsigStatus() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) TsigTimersOnly(bool) {}
|
||||||
|
|
||||||
|
func (m *MockResponseWriter) Hijack() {}
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -16,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/local"
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -23,6 +25,7 @@ import (
|
|||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
@@ -70,6 +73,9 @@ type DefaultServer struct {
|
|||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
extraDomains map[domain.Domain]int
|
extraDomains map[domain.Domain]int
|
||||||
|
|
||||||
|
// management cache resolver for critical infrastructure domains
|
||||||
|
mgmtCacheResolver *mgmt.Resolver
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
hostsDNSHolder *hostsDNSHolder
|
hostsDNSHolder *hostsDNSHolder
|
||||||
@@ -105,6 +111,8 @@ func NewDefaultServer(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
stateManager *statemanager.Manager,
|
stateManager *statemanager.Manager,
|
||||||
disableSys bool,
|
disableSys bool,
|
||||||
|
mgmtURL *url.URL,
|
||||||
|
netbirdConfig *mgmProto.NetbirdConfig,
|
||||||
) (*DefaultServer, error) {
|
) (*DefaultServer, error) {
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
@@ -122,7 +130,29 @@ func NewDefaultServer(
|
|||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil
|
server := newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys)
|
||||||
|
|
||||||
|
// Pre-populate management cache with management URL
|
||||||
|
if mgmtURL != nil && server.mgmtCacheResolver != nil {
|
||||||
|
if err := server.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL); err != nil {
|
||||||
|
log.Warnf("Failed to populate management cache from management URL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-populate management cache with NetbirdConfig domains
|
||||||
|
if netbirdConfig != nil && server.mgmtCacheResolver != nil {
|
||||||
|
if err := server.mgmtCacheResolver.PopulateFromNetbirdConfig(ctx, netbirdConfig); err != nil {
|
||||||
|
log.Warnf("Failed to populate management cache from NetbirdConfig: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register newly populated domains
|
||||||
|
domains := server.mgmtCacheResolver.GetCachedDomains()
|
||||||
|
if len(domains) > 0 {
|
||||||
|
server.RegisterHandler(domains, server.mgmtCacheResolver, PriorityMgmtCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
@@ -170,6 +200,10 @@ func newDefaultServer(
|
|||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
handlerChain := NewHandlerChain()
|
handlerChain := NewHandlerChain()
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
// Create management cache resolver
|
||||||
|
mgmtCacheResolver := mgmt.NewResolver()
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
@@ -183,8 +217,22 @@ func newDefaultServer(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
stateManager: stateManager,
|
stateManager: stateManager,
|
||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
|
mgmtCacheResolver: mgmtCacheResolver,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register cached domains with the handler chain
|
||||||
|
registerMgmtCacheDomains := func() {
|
||||||
|
domains := mgmtCacheResolver.GetCachedDomains()
|
||||||
|
if len(domains) > 0 {
|
||||||
|
defaultServer.RegisterHandler(domains, mgmtCacheResolver, PriorityMgmtCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register any pre-populated domains from management cache
|
||||||
|
registerMgmtCacheDomains()
|
||||||
|
|
||||||
|
// Management cache resolver will be registered for specific domains when they are added
|
||||||
|
|
||||||
// register with root zone, handler chain takes care of the routing
|
// register with root zone, handler chain takes care of the routing
|
||||||
dnsService.RegisterMux(".", handlerChain)
|
dnsService.RegisterMux(".", handlerChain)
|
||||||
|
|
||||||
@@ -208,7 +256,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
log.Debugf("registering handler %s with priority %d", handler, priority)
|
log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains)
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if domain == "" {
|
if domain == "" {
|
||||||
@@ -236,7 +284,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
log.Debugf("deregistering handler %v with priority %d", domains, priority)
|
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if domain == "" {
|
if domain == "" {
|
||||||
@@ -304,11 +352,32 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
|
|
||||||
maps.Clear(s.extraDomains)
|
maps.Clear(s.extraDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateMgmtCacheFromConfig populates the management cache with domains from the client configuration
|
||||||
|
func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
|
||||||
|
if s.mgmtCacheResolver == nil {
|
||||||
|
return fmt.Errorf("management cache resolver not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("populating management cache from client configuration")
|
||||||
|
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopulateMgmtCacheFromNetbirdConfig populates the management cache with domains from the netbird configuration
|
||||||
|
func (s *DefaultServer) PopulateMgmtCacheFromNetbirdConfig(config *mgmProto.NetbirdConfig) error {
|
||||||
|
if s.mgmtCacheResolver == nil {
|
||||||
|
return fmt.Errorf("management cache resolver not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("populating management cache from netbird configuration")
|
||||||
|
return s.mgmtCacheResolver.PopulateFromNetbirdConfig(s.ctx, config)
|
||||||
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
|
|
||||||
|
@@ -363,7 +363,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -473,7 +473,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -575,7 +575,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false)
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
|
@@ -75,7 +75,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
|||||||
|
|
||||||
// String returns a string representation of the upstream resolver
|
// String returns a string representation of the upstream resolver
|
||||||
func (u *upstreamResolverBase) String() string {
|
func (u *upstreamResolverBase) String() string {
|
||||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
return fmt.Sprintf("Upstream %v", u.upstreamServers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID
|
||||||
|
@@ -7,6 +7,7 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -124,6 +125,12 @@ type EngineConfig struct {
|
|||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
|
|
||||||
LazyConnectionEnabled bool
|
LazyConnectionEnabled bool
|
||||||
|
|
||||||
|
// ManagementURL is the URL of the management server for DNS cache
|
||||||
|
ManagementURL *url.URL
|
||||||
|
|
||||||
|
// NetbirdConfig contains signal, relay, and flow server configuration
|
||||||
|
NetbirdConfig *mgmProto.NetbirdConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -387,7 +394,7 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("read initial settings: %w", err)
|
return fmt.Errorf("read initial settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := e.newDnsServer(dnsConfig)
|
dnsServer, err := e.newDnsServer(dnsConfig, e.config.ManagementURL, e.config.NetbirdConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("create dns server: %w", err)
|
return fmt.Errorf("create dns server: %w", err)
|
||||||
@@ -1572,7 +1579,7 @@ func (e *Engine) wgInterfaceCreate() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbirdConfig *mgmProto.NetbirdConfig) (dns.Server, error) {
|
||||||
// due to tests where we are using a mocked version of the DNS server
|
// due to tests where we are using a mocked version of the DNS server
|
||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
return e.dnsServer, nil
|
return e.dnsServer, nil
|
||||||
@@ -1597,7 +1604,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
|||||||
return dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS, mgmtURL, netbirdConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1616,6 +1623,11 @@ func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
|||||||
return e.firewall
|
return e.firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDNSServer returns the DNS server
|
||||||
|
func (e *Engine) GetDNSServer() dns.Server {
|
||||||
|
return e.dnsServer
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Reference in New Issue
Block a user