This commit is contained in:
Viktor Liu
2025-07-11 11:06:19 +02:00
parent 629757c911
commit 90bf1baec2
10 changed files with 504 additions and 583 deletions

View File

@@ -272,11 +272,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil {
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
@@ -442,8 +443,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled,
ManagementURL: config.ManagementURL,
NetbirdConfig: netbirdConfig,
}
if config.PreSharedKey != "" {

View File

@@ -5,6 +5,7 @@ import (
"net"
"net/netip"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
@@ -40,19 +41,34 @@ func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
// extractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func extractValidDomain(rawURL string) (domain.Domain, error) {
parsedURL, err := url.Parse(rawURL)
if err != nil {
// If URL parsing fails, it might be a raw host:port, try parsing as such
if host, _, err := net.SplitHostPort(rawURL); err == nil {
return extractDomainFromHost(host)
}
// If not host:port, try as raw hostname
return extractDomainFromHost(rawURL)
if rawURL == "" {
return "", fmt.Errorf("empty URL")
}
host := parsedURL.Hostname()
if host == "" {
return "", fmt.Errorf("no hostname in URL")
// Try standard URL parsing first (handles https://, http://, rels://, etc.)
if parsedURL, err := url.Parse(rawURL); err == nil && parsedURL.Hostname() != "" {
return extractDomainFromHost(parsedURL.Hostname())
}
// Extract domain from various formats:
// - stun:domain:port -> domain
// - turns:domain:port?params -> domain
// - domain:port -> domain
host := rawURL
// Remove scheme prefix (stun:, turn:, turns:)
if colonIndex := strings.Index(host, ":"); colonIndex > 0 && colonIndex < 10 && !strings.Contains(host[:colonIndex], ".") {
host = host[colonIndex+1:]
}
// Remove port suffix
if hostOnly, _, err := net.SplitHostPort(host); err == nil {
host = hostOnly
}
// Remove query parameters
if queryIndex := strings.Index(host, "?"); queryIndex > 0 {
host = host[:queryIndex]
}
return extractDomainFromHost(host)

View File

@@ -0,0 +1,148 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtractValidDomain(t *testing.T) {
tests := []struct {
name string
url string
expected string
expectError bool
}{
{
name: "HTTPS URL with port",
url: "https://api.netbird.io:443",
expected: "api.netbird.io",
},
{
name: "HTTP URL without port",
url: "http://signal.example.com",
expected: "signal.example.com",
},
{
name: "Host with port (no scheme)",
url: "signal.netbird.io:443",
expected: "signal.netbird.io",
},
{
name: "STUN URL",
url: "stun:stun.netbird.io:443",
expected: "stun.netbird.io",
},
{
name: "STUN URL with different port",
url: "stun:stun.netbird.io:5555",
expected: "stun.netbird.io",
},
{
name: "TURNS URL with query params",
url: "turns:turn.netbird.io:443?transport=tcp",
expected: "turn.netbird.io",
},
{
name: "TURN URL",
url: "turn:turn.example.com:3478",
expected: "turn.example.com",
},
{
name: "REL URL",
url: "rel://relay.example.com:443",
expected: "relay.example.com",
},
{
name: "RELS URL",
url: "rels://relay.netbird.io:443",
expected: "relay.netbird.io",
},
{
name: "Raw hostname",
url: "example.org",
expected: "example.org",
},
{
name: "IP address should be rejected",
url: "192.168.1.1",
expectError: true,
},
{
name: "IP address with port should be rejected",
url: "192.168.1.1:443",
expectError: true,
},
{
name: "IPv6 address should be rejected",
url: "2001:db8::1",
expectError: true,
},
{
name: "Empty URL",
url: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractValidDomain(tt.url)
if tt.expectError {
assert.Error(t, err, "Expected error for URL: %s", tt.url)
} else {
assert.NoError(t, err, "Unexpected error for URL: %s", tt.url)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url)
}
})
}
}
func TestExtractDomainFromHost(t *testing.T) {
tests := []struct {
name string
host string
expected string
expectError bool
}{
{
name: "Valid domain",
host: "example.com",
expected: "example.com",
},
{
name: "Subdomain",
host: "api.example.com",
expected: "api.example.com",
},
{
name: "IPv4 address",
host: "192.168.1.1",
expectError: true,
},
{
name: "IPv6 address",
host: "2001:db8::1",
expectError: true,
},
{
name: "Empty host",
host: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractDomainFromHost(tt.host)
if tt.expectError {
assert.Error(t, err, "Expected error for host: %s", tt.host)
} else {
assert.NoError(t, err, "Unexpected error for host: %s", tt.host)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host)
}
})
}
}

View File

@@ -16,25 +16,19 @@ import (
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// CacheEntry holds DNS records for a cached domain
type CacheEntry struct {
ARecords []dns.RR
AAAARecords []dns.RR
}
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
cache map[domain.Domain]CacheEntry
records map[dns.Question][]dns.RR
managementDomain *domain.Domain
mutex sync.RWMutex
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
cache: make(map[domain.Domain]CacheEntry),
records: make(map[dns.Question][]dns.RR),
}
}
@@ -51,7 +45,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
question := r.Question[0]
qname := strings.ToLower(strings.TrimSuffix(question.Name, "."))
question.Name = strings.ToLower(dns.Fqdn(question.Name))
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
m.continueToNext(w, r)
@@ -59,8 +53,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
m.mutex.RLock()
domainKey := domain.Domain(qname)
entry, found := m.cache[domainKey]
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
@@ -73,34 +66,19 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp.Authoritative = false
resp.RecursionAvailable = true
var records []dns.RR
if question.Qtype == dns.TypeA {
records = entry.ARecords
} else if question.Qtype == dns.TypeAAAA {
records = entry.AAAARecords
}
resp.Answer = append(resp.Answer, records...)
if len(records) == 0 {
m.continueToNext(w, r)
return
}
for _, rr := range records {
rrCopy := dns.Copy(rr)
rrCopy.Header().Name = question.Name
resp.Answer = append(resp.Answer, rrCopy)
}
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), domainKey.SafeString())
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write response: %v", err)
}
}
// MatchSubdomains always returns true as required by the interface.
// MatchSubdomains returns false since this resolver only handles exact domain matches
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
func (m *Resolver) MatchSubdomains() bool {
return true
return false
}
// continueToNext signals the handler chain to continue to the next handler.
@@ -115,8 +93,6 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("adding domain=%s to cache", d.SafeString())
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
@@ -130,7 +106,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
@@ -141,7 +117,7 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: d.PunycodeString() + ".",
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
@@ -153,10 +129,25 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
}
m.mutex.Lock()
m.cache[d] = CacheEntry{
ARecords: aRecords,
AAAARecords: aaaaRecords,
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
@@ -167,12 +158,21 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL != nil {
if d, err := extractDomainFromURL(mgmtURL); err == nil {
if mgmtURL == nil {
return nil
}
d, err := extractDomainFromURL(mgmtURL)
if err != nil {
return fmt.Errorf("extract domain from URL: %w", err)
}
m.mutex.Lock()
m.managementDomain = &d
m.mutex.Unlock()
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add management domain: %v", err)
}
}
return fmt.Errorf("add domain: %w", err)
}
return nil
@@ -183,191 +183,95 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.cache, d)
aQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: strings.ToLower(dns.Fqdn(d.PunycodeString())),
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
}
// PopulateFromNetbirdConfig extracts and caches domains from the netbird config.
func (m *Resolver) PopulateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) error {
if config == nil {
return nil
}
m.addSignalDomain(ctx, config.Signal)
m.addRelayDomains(ctx, config.Relay)
m.addFlowDomain(ctx, config.Flow)
m.addStunDomains(ctx, config.Stuns)
m.addTurnDomains(ctx, config.Turns)
return nil
}
// addSignalDomain adds signal server domain to cache.
func (m *Resolver) addSignalDomain(ctx context.Context, signal *mgmProto.HostConfig) {
if signal == nil || signal.Uri == "" {
return
}
signalURL, err := url.Parse(signal.Uri)
if err != nil {
// If parsing fails, it might be a raw host:port, try adding a scheme
signalURL, err = url.Parse("https://" + signal.Uri)
if err != nil {
log.Warnf("failed to parse signal URL: %v", err)
return
}
}
d, err := extractDomainFromURL(signalURL)
if err != nil {
log.Warnf("failed to extract signal domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add signal domain: %v", err)
}
}
// addRelayDomains adds relay server domains to cache.
func (m *Resolver) addRelayDomains(ctx context.Context, relay *mgmProto.RelayConfig) {
if relay == nil {
return
}
for _, relayAddr := range relay.Urls {
relayURL, err := url.Parse(relayAddr)
if err != nil {
log.Warnf("failed to parse relay URL %s: %v", relayAddr, err)
continue
}
d, err := extractDomainFromURL(relayURL)
if err != nil {
log.Warnf("failed to extract relay domain from %s: %v", relayAddr, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add relay domain: %v", err)
}
}
}
// addFlowDomain adds traffic flow server domain to cache.
func (m *Resolver) addFlowDomain(ctx context.Context, flow *mgmProto.FlowConfig) {
if flow == nil || flow.Url == "" {
return
}
flowURL, err := url.Parse(flow.Url)
if err != nil {
log.Warnf("failed to parse flow URL: %v", err)
return
}
d, err := extractDomainFromURL(flowURL)
if err != nil {
log.Warnf("failed to extract flow domain: %v", err)
return
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add flow domain: %v", err)
}
}
// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() []domain.Domain {
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
defer m.mutex.RUnlock()
domains := make([]domain.Domain, 0, len(m.cache))
for d := range m.cache {
domainSet := make(map[domain.Domain]struct{})
for question := range m.records {
domainName := strings.TrimSuffix(question.Name, ".")
domainSet[domain.Domain(domainName)] = struct{}{}
}
domains := make(domain.List, 0, len(domainSet))
for d := range domainSet {
domains = append(domains, d)
}
return domains
}
// ClearCache removes all cached domains and returns them for external deregistration.
func (m *Resolver) ClearCache() []domain.Domain {
m.mutex.Lock()
defer m.mutex.Unlock()
domains := make([]domain.Domain, 0, len(m.cache))
for d := range m.cache {
domains = append(domains, d)
}
m.cache = make(map[domain.Domain]CacheEntry)
log.Debugf("cleared %d cached domains", len(domains))
return domains
}
// UpdateFromNetbirdConfig updates the cache intelligently by comparing current and new configurations.
// Returns domains that were removed for external deregistration.
func (m *Resolver) UpdateFromNetbirdConfig(ctx context.Context, config *mgmProto.NetbirdConfig) ([]domain.Domain, error) {
log.Debugf("updating cache from NetbirdConfig")
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromConfig(config)
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains []domain.Domain
for _, currentDomain := range currentDomains {
found := false
for _, newDomain := range newDomains {
if currentDomain.SafeString() == newDomain.SafeString() {
found = true
break
}
}
if !found {
removedDomains = append(removedDomains, currentDomain)
}
}
m.mutex.Lock()
for _, domainToRemove := range removedDomains {
delete(m.cache, domainToRemove)
log.Debugf("removed domain=%s from cache", domainToRemove.SafeString())
}
m.mutex.Unlock()
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
}
}
removedDomains := m.removeStaleDomainsExceptManagement(currentDomains, newDomains)
m.addNewDomains(ctx, newDomains)
return removedDomains, nil
}
// UpdateFromServerDomains updates the cache using the simplified ServerDomains struct
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) ([]domain.Domain, error) {
log.Debugf("updating cache from ServerDomains")
// removeStaleDomainsExceptManagement removes domains not in newDomains, except management domain
func (m *Resolver) removeStaleDomainsExceptManagement(currentDomains, newDomains domain.List) domain.List {
var removedDomains domain.List
currentDomains := m.GetCachedDomains()
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains []domain.Domain
for _, currentDomain := range currentDomains {
found := false
for _, newDomain := range newDomains {
if currentDomain.SafeString() == newDomain.SafeString() {
found = true
break
if m.isDomainInList(currentDomain, newDomains) {
continue
}
if m.isManagementDomain(currentDomain) {
continue
}
if !found {
removedDomains = append(removedDomains, currentDomain)
if err := m.RemoveDomain(currentDomain); err != nil {
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
}
}
return removedDomains
}
// isDomainInList checks if domain exists in the list
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
for _, d := range list {
if domain.SafeString() == d.SafeString() {
return true
}
}
return false
}
// isManagementDomain checks if domain is the protected management domain
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.managementDomain != nil && domain.SafeString() == m.managementDomain.SafeString()
}
// addNewDomains adds all new domains to the cache
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
@@ -375,12 +279,10 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
return removedDomains, nil
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) []domain.Domain {
var domains []domain.Domain
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains domain.List
if serverDomains.Signal != "" {
domains = append(domains, serverDomains.Signal)
@@ -411,164 +313,6 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// extractDomainsFromConfig extracts all domains from a NetbirdConfig.
func (m *Resolver) extractDomainsFromConfig(config *mgmProto.NetbirdConfig) []domain.Domain {
if config == nil {
return nil
}
var domains []domain.Domain
// Extract signal domain
domains = append(domains, m.extractSignalDomain(config)...)
// Extract relay domains
domains = append(domains, m.extractRelayDomains(config)...)
// Extract flow domain
domains = append(domains, m.extractFlowDomain(config)...)
// Extract STUN domains
domains = append(domains, m.extractSTUNDomains(config)...)
// Extract TURN domains
domains = append(domains, m.extractTURNDomains(config)...)
return domains
}
func (m *Resolver) extractSignalDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Signal == nil || config.Signal.Uri == "" {
return nil
}
if d, err := m.extractDomainFromSignalConfig(config.Signal); err == nil {
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay == nil {
return nil
}
var domains []domain.Domain
for _, relayURL := range config.Relay.Urls {
if d, err := m.extractDomainFromURL(relayURL); err == nil {
domains = append(domains, d)
}
}
return domains
}
func (m *Resolver) extractFlowDomain(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Flow == nil || config.Flow.Url == "" {
return nil
}
if d, err := m.extractDomainFromURL(config.Flow.Url); err == nil {
return []domain.Domain{d}
}
return nil
}
func (m *Resolver) extractSTUNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
if d, err := m.extractDomainFromURL(stun.Uri); err == nil {
domains = append(domains, d)
}
}
}
return domains
}
func (m *Resolver) extractTURNDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var domains []domain.Domain
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
if d, err := m.extractDomainFromURL(turn.HostConfig.Uri); err == nil {
domains = append(domains, d)
}
}
}
return domains
}
// extractDomainFromSignalConfig extracts domain from signal configuration.
func (m *Resolver) extractDomainFromSignalConfig(signal *mgmProto.HostConfig) (domain.Domain, error) {
signalURL, err := url.Parse(signal.Uri)
if err != nil {
// If parsing fails, it might be a raw host:port, try adding a scheme
signalURL, err = url.Parse("https://" + signal.Uri)
if err != nil {
return "", err
}
}
return extractDomainFromURL(signalURL)
}
// extractDomainFromURL extracts domain from a URL string.
func (m *Resolver) extractDomainFromURL(urlStr string) (domain.Domain, error) {
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", err
}
return extractDomainFromURL(parsedURL)
}
// addStunDomains adds STUN server domains to cache.
func (m *Resolver) addStunDomains(ctx context.Context, stuns []*mgmProto.HostConfig) {
for _, stun := range stuns {
if stun == nil || stun.Uri == "" {
continue
}
stunURL, err := url.Parse(stun.Uri)
if err != nil {
log.Warnf("failed to parse STUN URL %s: %v", stun.Uri, err)
continue
}
d, err := extractDomainFromURL(stunURL)
if err != nil {
log.Warnf("failed to extract STUN domain from %s: %v", stun.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add STUN domain: %v", err)
}
}
}
// addTurnDomains adds TURN server domains to cache.
func (m *Resolver) addTurnDomains(ctx context.Context, turns []*mgmProto.ProtectedHostConfig) {
for _, turn := range turns {
if turn == nil || turn.HostConfig == nil || turn.HostConfig.Uri == "" {
continue
}
turnURL, err := url.Parse(turn.HostConfig.Uri)
if err != nil {
log.Warnf("failed to parse TURN URL %s: %v", turn.HostConfig.Uri, err)
continue
}
d, err := extractDomainFromURL(turnURL)
if err != nil {
log.Warnf("failed to extract TURN domain from %s: %v", turn.HostConfig.Uri, err)
continue
}
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add TURN domain: %v", err)
}
}
}
// extractDomainFromURL extracts the domain from a URL.
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {

View File

@@ -2,22 +2,24 @@ package mgmt
import (
"context"
"net"
"net/url"
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
mgmProto "github.com/netbirdio/netbird/management/proto"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/management/domain"
)
func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver()
assert.NotNil(t, resolver)
assert.NotNil(t, resolver.cache)
assert.True(t, resolver.MatchSubdomains())
assert.NotNil(t, resolver.records)
assert.False(t, resolver.MatchSubdomains())
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
@@ -113,145 +115,173 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
resolver := NewResolver()
// Use IP address to avoid DNS resolution timeout
// Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.NoError(t, err)
assert.Error(t, err)
assert.Contains(t, err.Error(), "host is an IP address")
// IP addresses are rejected, so no domains should be cached
// No domains should be cached when using IP addresses
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_PopulateFromNetbirdConfig(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
func TestResolver_ServeDNS(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Use IP addresses to avoid DNS resolution timeouts
netbirdConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{
Uri: "https://10.0.0.1",
},
Relay: &mgmProto.RelayConfig{
Urls: []string{
"https://10.0.0.2:443",
"https://10.0.0.3:443",
},
},
Flow: &mgmProto.FlowConfig{
Url: "https://10.0.0.4:80",
},
Stuns: []*mgmProto.HostConfig{
{Uri: "stun:10.0.0.5:3478"},
{Uri: "stun:10.0.0.6:3478"},
},
Turns: []*mgmProto.ProtectedHostConfig{
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:10.0.0.7:3478",
},
},
{
HostConfig: &mgmProto.HostConfig{
Uri: "turn:10.0.0.8:3478",
},
},
// Add a test domain to the cache - use example.org which is reserved for testing
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Test A record query for cached domain
t.Run("Cached domain A record", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
err := resolver.PopulateFromNetbirdConfig(ctx, netbirdConfig)
assert.NoError(t, err)
req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
// IP addresses are rejected, so no domains should be cached
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
resolver.ServeDNS(mockWriter, req)
func TestResolver_UpdateFromNetbirdConfig(t *testing.T) {
resolver := NewResolver()
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
})
// Test with empty initial config and then add domains
initialConfig := &mgmProto.NetbirdConfig{}
// Start with empty config
removedDomains, err := resolver.UpdateFromNetbirdConfig(context.Background(), initialConfig)
assert.NoError(t, err)
assert.Equal(t, 0, len(removedDomains), "No domains should be removed from empty cache")
// Update to config with IP addresses instead of domains to avoid DNS resolution
// IP addresses will be rejected by extractDomainFromURL so no actual resolution happens
updatedConfig := &mgmProto.NetbirdConfig{
Signal: &mgmProto.HostConfig{
Uri: "https://127.0.0.1",
},
Flow: &mgmProto.FlowConfig{
Url: "https://192.168.1.1:80",
// Test uncached domain signals to continue to next handler
t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
removedDomains, err = resolver.UpdateFromNetbirdConfig(context.Background(), updatedConfig)
assert.NoError(t, err)
// Verify the method completes successfully without DNS timeouts
assert.GreaterOrEqual(t, len(removedDomains), 0, "Should not error on config update")
// Verify no domains were actually added since IPs are rejected
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_ContinueToNext(t *testing.T) {
resolver := NewResolver()
// Create a mock response writer to capture the response
mockWriter := &MockResponseWriter{}
// Create a test DNS query
req := new(dns.Msg)
req.SetQuestion("unknown.example.com.", dns.TypeA)
// Call continueToNext
resolver.continueToNext(mockWriter, req)
resolver.ServeDNS(mockWriter, req)
// Verify the response
assert.NotNil(t, mockWriter.msg)
assert.Equal(t, dns.RcodeNameError, mockWriter.msg.Rcode)
assert.True(t, mockWriter.msg.MsgHdr.Zero)
}
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
// Zero flag set to true signals the handler chain to continue to next handler
assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
})
// MockResponseWriter is a simple mock implementation of dns.ResponseWriter for testing
type MockResponseWriter struct {
msg *dns.Msg
}
func (m *MockResponseWriter) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53}
}
func (m *MockResponseWriter) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
}
func (m *MockResponseWriter) WriteMsg(msg *dns.Msg) error {
m.msg = msg
// Test that subdomains of cached domains are NOT resolved
t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
func (m *MockResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
// Query for a subdomain of our cached domain
req := new(dns.Msg)
req.SetQuestion("sub.example.org.", dns.TypeA)
func (m *MockResponseWriter) Close() error {
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
})
// Test case-insensitive matching
t.Run("Case-insensitive domain matching", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
func (m *MockResponseWriter) TsigStatus() error {
return nil
// Query with different casing
req := new(dns.Msg)
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
})
}
func (m *MockResponseWriter) TsigTimersOnly(bool) {}
func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
func (m *MockResponseWriter) Hijack() {}
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
cachedDomains := resolver.GetCachedDomains()
assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
}
func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
mgmtURL, _ := url.Parse("https://example.org")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
initialDomains := resolver.GetCachedDomains()
if len(initialDomains) == 0 {
t.Skip("Management domain failed to resolve, skipping test")
}
assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
assert.Equal(t, "example.org", initialDomains[0].SafeString())
serverDomains := dnsconfig.ServerDomains{
Signal: "google.com",
Relay: []domain.Domain{"cloudflare.com"},
}
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
if err != nil {
t.Logf("Server domains update failed: %v", err)
}
finalDomains := resolver.GetCachedDomains()
managementStillCached := false
for _, d := range finalDomains {
if d.SafeString() == "example.org" {
managementStillCached = true
break
}
}
assert.True(t, managementStillCached, "Management domain should never be removed")
}

View File

@@ -74,7 +74,6 @@ type DefaultServer struct {
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
// management cache resolver for critical infrastructure domains
mgmtCacheResolver *mgmt.Resolver
// permanent related properties
@@ -106,18 +105,15 @@ type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
Ctx context.Context
WgInterface WGIface
CustomAddress string
StatusRecorder *peer.Status
StateManager *statemanager.Manager
DisableSys bool
MgmtURL *url.URL
ServerDomains dnsconfig.ServerDomains
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
@@ -134,20 +130,7 @@ func NewDefaultServer(config DefaultServerConfig) (*DefaultServer, error) {
dnsService = newServiceViaListener(config.WgInterface, addrPort)
}
server := newDefaultServer(config.Ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
if config.MgmtURL != nil && server.mgmtCacheResolver != nil {
if err := server.mgmtCacheResolver.PopulateFromConfig(config.Ctx, config.MgmtURL); err != nil {
log.Warnf("Failed to populate management cache from management URL: %v", err)
}
}
if server.mgmtCacheResolver != nil {
if err := server.UpdateServerConfig(config.ServerDomains); err != nil {
log.Warnf("Failed to populate management cache from ServerDomains: %v", err)
}
}
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
return server, nil
}
@@ -197,7 +180,6 @@ func newDefaultServer(
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
// Create management cache resolver
mgmtCacheResolver := mgmt.NewResolver()
defaultServer := &DefaultServer{
@@ -345,20 +327,8 @@ func (s *DefaultServer) Stop() {
maps.Clear(s.extraDomains)
}
// PopulateMgmtCacheFromConfig populates the management cache with domains from the client configuration
func (s *DefaultServer) PopulateMgmtCacheFromConfig(mgmtURL *url.URL) error {
if s.mgmtCacheResolver == nil {
return fmt.Errorf("management cache resolver not initialized")
}
log.Debug("populating management cache from client configuration")
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDNSHolder.set(hostsDnsList)
@@ -465,12 +435,12 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
}
if len(removedDomains) > 0 {
s.DeregisterHandler(removedDomains, PriorityMgmtCache)
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
if len(newDomains) > 0 {
s.RegisterHandler(newDomains, s.mgmtCacheResolver, PriorityMgmtCache)
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
}
}
@@ -935,3 +905,11 @@ func toZone(d domain.Domain) domain.Domain {
),
)
}
// PopulateManagementDomain populates the DNS cache with management domain
func (s *DefaultServer) PopulateManagementDomain(ctx context.Context, mgmtURL *url.URL) error {
if s.mgmtCacheResolver != nil && mgmtURL != nil {
return s.mgmtCacheResolver.PopulateFromConfig(ctx, mgmtURL)
}
return nil
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/dns/types"
@@ -364,15 +363,12 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Fatal(err)
@@ -483,15 +479,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Errorf("create DNS server: %v", err)
@@ -594,15 +587,12 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(DefaultServerConfig{
Ctx: context.Background(),
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
MgmtURL: nil,
ServerDomains: dnsconfig.ServerDomains{},
})
if err != nil {
t.Fatalf("%v", err)

View File

@@ -33,7 +33,7 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/config"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -126,12 +126,6 @@ type EngineConfig struct {
BlockInbound bool
LazyConnectionEnabled bool
// ManagementURL is the URL of the management server for DNS cache
ManagementURL *url.URL
// NetbirdConfig contains signal, relay, and flow server configuration
NetbirdConfig *mgmProto.NetbirdConfig
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -350,7 +344,7 @@ func (e *Engine) Stop() error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error {
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -395,13 +389,18 @@ func (e *Engine) Start() error {
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig, e.config.ManagementURL, e.config.NetbirdConfig)
dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
}
e.dnsServer = dnsServer
// Populate DNS cache with NetbirdConfig and management URL for early resolution
if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache: %v", err)
}
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: e.ctx,
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
@@ -666,6 +665,32 @@ func (e *Engine) removePeer(peerKey string) error {
return nil
}
// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response
func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
if e.dnsServer == nil {
return nil
}
// Populate management URL if provided
if mgmtURL != nil {
if defaultServer, ok := e.dnsServer.(*dns.DefaultServer); ok {
if err := defaultServer.PopulateManagementDomain(e.ctx, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
}
}
// Populate NetbirdConfig domains if provided
if netbirdConfig != nil {
serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err)
}
}
return nil
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -697,12 +722,9 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if e.dnsServer != nil {
serverDomains := config.ExtractFromNetbirdConfig(wCfg)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
}
// todo update signal
}
@@ -1587,7 +1609,7 @@ func (e *Engine) wgInterfaceCreate() (err error) {
return err
}
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbirdConfig *mgmProto.NetbirdConfig) (dns.Server, error) {
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
// due to tests where we are using a mocked version of the DNS server
if e.dnsServer != nil {
return e.dnsServer, nil
@@ -1612,18 +1634,13 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config, mgmtURL *url.URL, netbird
return dnsServer, nil
default:
// Extract domains from NetBird configuration
serverDomains := config.ExtractFromNetbirdConfig(netbirdConfig)
dnsServer, err := dns.NewDefaultServer(dns.DefaultServerConfig{
Ctx: e.ctx,
dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,
DisableSys: e.config.DisableDNS,
MgmtURL: mgmtURL,
ServerDomains: serverDomains,
})
if err != nil {
return nil, err
@@ -1643,11 +1660,6 @@ func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall
}
// GetDNSServer returns the DNS server
func (e *Engine) GetDNSServer() dns.Server {
return e.dnsServer
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -261,7 +261,7 @@ func TestEngine_SSH(t *testing.T) {
},
}, nil
}
err = engine.Start()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -605,7 +605,7 @@ func TestEngine_Sync(t *testing.T) {
}
}()
err = engine.Start()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
@@ -1060,7 +1060,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start()
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()

View File

@@ -39,7 +39,7 @@ func IsLoginRequired(ctx context.Context, config *Config) (bool, error) {
return false, err
}
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
@@ -68,14 +68,18 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
return err
}
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err
}
return err
return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
@@ -100,11 +104,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, error) {
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
@@ -120,8 +124,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.BlockInbound,
config.LazyConnectionEnabled,
)
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, err
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.