Add dns interceptor based domain route functionality (#3032)

This commit is contained in:
Zoltan Papp 2024-12-13 14:17:10 +01:00 committed by GitHub
parent a145f0b811
commit c91d7808bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 1740 additions and 163 deletions

View File

@ -46,7 +46,7 @@ jobs:
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v3 uses: golangci/golangci-lint-action@v4
with: with:
version: latest version: latest
args: --timeout=12m args: --timeout=12m --out-format colored-line-number

View File

@ -0,0 +1,192 @@
package dns
import (
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 0
)
type HandlerEntry struct {
Handler dns.Handler
Priority int
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
}
// HandlerChain represents a prioritized chain of DNS handlers
type HandlerChain struct {
mu sync.RWMutex
handlers []HandlerEntry
}
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
dns.ResponseWriter
shouldContinue bool
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
w.shouldContinue = true
return nil
}
return w.ResponseWriter.WriteMsg(m)
}
func NewHandlerChain() *HandlerChain {
return &HandlerChain{
handlers: make([]HandlerEntry, 0),
}
}
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
c.mu.Lock()
defer c.mu.Unlock()
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard {
pattern = pattern[2:]
}
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)
// First remove any existing handler with same original pattern and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
pattern, origPattern, isWildcard, priority)
entry := HandlerEntry{
Handler: handler,
Priority: priority,
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
}
// Insert handler in priority order
pos := 0
for i, h := range c.handlers {
if h.Priority < priority {
pos = i
break
}
pos = i + 1
}
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
}
// RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
}
}
}
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = dns.Fqdn(pattern)
for _, entry := range c.handlers {
if entry.Pattern == pattern {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := r.Question[0].Name
log.Debugf("handling DNS request for %s", qname)
c.mu.RLock()
defer c.mu.RUnlock()
log.Debugf("current handlers (%d):", len(c.handlers))
for _, h := range c.handlers {
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
}
// Try handlers in priority order
for _, entry := range c.handlers {
var matched bool
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
}
if !matched {
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
entry.OrigPattern, qname, entry.IsWildcard)
continue
}
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
entry.OrigPattern, qname, entry.IsWildcard)
chainWriter := &ResponseWriterChain{ResponseWriter: w}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Debugf("handler requested continue to next handler")
continue
}
return
}
// No handler matched or all handlers passed
log.Debugf("no handler found for %s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}

View File

@ -0,0 +1,490 @@
package dns_test
import (
"net"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
)
// MockHandler implements dns.Handler interface for testing
type MockHandler struct {
mock.Mock
}
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.Called(w, r)
}
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create mock handlers for different priorities
defaultHandler := &MockHandler{}
matchDomainHandler := &MockHandler{}
dnsRouteHandler := &MockHandler{}
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
// Execute
chain.ServeDNS(w, r)
// Verify all expectations were met
dnsRouteHandler.AssertExpectations(t)
matchDomainHandler.AssertExpectations(t)
defaultHandler.AssertExpectations(t)
}
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct {
name string
handlerDomain string
queryDomain string
isWildcard bool
shouldMatch bool
}{
{
name: "exact match",
handlerDomain: "example.com.",
queryDomain: "example.com.",
isWildcard: false,
shouldMatch: true,
},
{
name: "subdomain with non-wildcard",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
shouldMatch: true,
},
{
name: "wildcard match",
handlerDomain: "*.example.com.",
queryDomain: "sub.example.com.",
isWildcard: true,
shouldMatch: true,
},
{
name: "wildcard no match on apex",
handlerDomain: "*.example.com.",
queryDomain: "example.com.",
isWildcard: true,
shouldMatch: false,
},
{
name: "root zone match",
handlerDomain: ".",
queryDomain: "anything.com.",
isWildcard: false,
shouldMatch: true,
},
{
name: "no match different domain",
handlerDomain: "example.com.",
queryDomain: "example.org.",
isWildcard: false,
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
mockHandler := &MockHandler{}
pattern := tt.handlerDomain
if tt.isWildcard {
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
}
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
if tt.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, r).Once()
}
chain.ServeDNS(w, r)
mockHandler.AssertExpectations(t)
})
}
}
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct {
name string
handlers []struct {
pattern string
priority int
}
queryDomain string
expectedCalls int
expectedHandler int // index of the handler that should be called
}{
{
name: "wildcard and exact same priority - exact should win",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDefault},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // exact match handler should be called
},
{
name: "higher priority wildcard over lower priority exact",
handlers: []struct {
pattern string
priority int
}{
{pattern: "example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority wildcard handler should be called
},
{
name: "multiple wildcards different priorities",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority handler should be called
},
{
name: "subdomain with mix of patterns",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "sub.test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "root zone with specific domain",
handlers: []struct {
pattern string
priority int
}{
{pattern: ".", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority specific domain should win over root
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
var handlers []*MockHandler
// Setup handlers and expectations
for i := range tt.handlers {
handler := &MockHandler{}
handlers = append(handlers, handler)
// Set expectation based on whether this handler should be called
if i == tt.expectedHandler {
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
} else {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
}
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
}
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify expectations
for _, handler := range handlers {
handler.AssertExpectations(t)
}
})
}
}
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create handlers
handler1 := &MockHandler{}
handler2 := &MockHandler{}
handler3 := &MockHandler{}
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Setup mock responses to simulate chain continuation
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// First handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true // Signal to continue
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Second handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Last handler responds normally
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
// Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify all handlers were called in order
handler1.AssertExpectations(t)
handler2.AssertExpectations(t)
handler3.AssertExpectations(t)
}
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct {
name string
ops []struct {
action string // "add" or "remove"
pattern string
priority int
}
query string
expectedCalls map[int]bool // map[priority]shouldBeCalled
}{
{
name: "remove high priority keeps lower priority handler",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true,
},
},
{
name: "remove lower priority keeps high priority handler",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false,
},
},
{
name: "remove all handlers in order",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDefault: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlers := make(map[int]*MockHandler)
// Execute operations
for _, op := range tt.ops {
if op.action == "add" {
handler := &MockHandler{}
handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil)
} else {
chain.RemoveHandler(op.pattern, op.priority)
}
}
// Create test request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations
for priority, handler := range handlers {
if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
handler.On("ServeDNS", mock.Anything, r).Once()
} else {
handler.On("ServeDNS", mock.Anything, r).Maybe()
}
}
// Execute request
chain.ServeDNS(w, r)
// Verify expectations
for _, handler := range handlers {
handler.AssertExpectations(t)
}
// Verify handler exists check
for priority, shouldExist := range tt.expectedCalls {
if shouldExist {
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
"Handler chain should have handlers for pattern after removing priority %d", priority)
}
}
})
}
}
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain := nbdns.NewHandlerChain()
testDomain := "example.com."
testQuery := "test.example.com."
// Create handlers for three priority levels
routeHandler := &MockHandler{}
matchHandler := &MockHandler{}
defaultHandler := &MockHandler{}
// Create test request that will be reused
r := new(dns.Msg)
r.SetQuestion(testQuery, dns.TypeA)
// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
// Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
routeHandler.AssertExpectations(t)
// Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
matchHandler.AssertExpectations(t)
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
defaultHandler.AssertExpectations(t)
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain))
}

View File

@ -3,6 +3,8 @@ package dns
import ( import (
"fmt" "fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@ -11,6 +13,20 @@ type MockServer struct {
InitializeFunc func() error InitializeFunc func() error
StopFunc func() StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler, int)
DeregisterHandlerFunc func([]string, int)
}
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
if m.RegisterHandlerFunc != nil {
m.RegisterHandlerFunc(domains, handler, priority)
}
}
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
if m.DeregisterHandlerFunc != nil {
m.DeregisterHandlerFunc(domains, priority)
}
} }
// Initialize mock implementation of Initialize from Server interface // Initialize mock implementation of Initialize from Server interface

View File

@ -30,6 +30,8 @@ type IosDnsManager interface {
// Server is a dns server interface // Server is a dns server interface
type Server interface { type Server interface {
RegisterHandler(domains []string, handler dns.Handler, priority int)
DeregisterHandler(domains []string, priority int)
Initialize() error Initialize() error
Stop() Stop()
DnsIP() string DnsIP() string
@ -48,12 +50,14 @@ type DefaultServer struct {
mux sync.Mutex mux sync.Mutex
service service service service
dnsMuxMap registeredHandlerMap dnsMuxMap registeredHandlerMap
handlerPriorities map[string]int
localResolver *localResolver localResolver *localResolver
wgInterface WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
updateSerial uint64 updateSerial uint64
previousConfigHash uint64 previousConfigHash uint64
currentConfig HostDNSConfig currentConfig HostDNSConfig
handlerChain *HandlerChain
// permanent related properties // permanent related properties
permanent bool permanent bool
@ -76,6 +80,7 @@ type handlerWithStop interface {
type muxUpdate struct { type muxUpdate struct {
domain string domain string
handler handlerWithStop handler handlerWithStop
priority int
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
@ -138,7 +143,9 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
service: dnsService, service: dnsService,
handlerChain: NewHandlerChain(),
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
handlerPriorities: make(map[string]int),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@ -151,6 +158,41 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
return defaultServer return defaultServer
} }
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
s.registerHandler(domains, handler, priority)
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority)
for _, domain := range domains {
s.handlerChain.AddHandler(domain, handler, priority, nil)
s.handlerPriorities[domain] = priority
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
}
}
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
s.deregisterHandler(domains, priority)
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
for _, domain := range domains {
s.handlerChain.RemoveHandler(domain, priority)
// Only deregister from service if no handlers remain
if !s.handlerChain.HasHandlers(domain) {
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
}
}
}
// Initialize instantiate host manager and the dns service // Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) { func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock() s.mux.Lock()
@ -343,7 +385,6 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
localRecords := make(map[string]nbdns.SimpleRecord, 0) localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones { for _, customZone := range customZones {
if len(customZone.Records) == 0 { if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records") return nil, nil, fmt.Errorf("received an empty list of records")
} }
@ -351,6 +392,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, muxUpdate{ muxUpdates = append(muxUpdates, muxUpdate{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain,
}) })
for _, record := range customZone.Records { for _, record := range customZone.Records {
@ -414,6 +456,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
muxUpdates = append(muxUpdates, muxUpdate{ muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone, domain: nbdns.RootZone,
handler: handler, handler: handler,
priority: PriorityDefault,
}) })
continue continue
} }
@ -431,6 +474,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
muxUpdates = append(muxUpdates, muxUpdate{ muxUpdates = append(muxUpdates, muxUpdate{
domain: domain, domain: domain,
handler: handler, handler: handler,
priority: PriorityMatchDomain,
}) })
} }
} }
@ -440,12 +484,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registeredHandlerMap) muxUpdateMap := make(registeredHandlerMap)
handlersByPriority := make(map[string]int)
var isContainRootUpdate bool var isContainRootUpdate bool
// First register new handlers
for _, update := range muxUpdates { for _, update := range muxUpdates {
s.service.RegisterMux(update.domain, update.handler) s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.domain] = update.handler muxUpdateMap[update.domain] = update.handler
handlersByPriority[update.domain] = update.priority
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop() existingHandler.stop()
} }
@ -455,6 +503,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
} }
} }
// Then deregister old handlers not in the update
for key, existingHandler := range s.dnsMuxMap { for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key] _, found := muxUpdateMap[key]
if !found { if !found {
@ -463,12 +512,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
existingHandler.stop() existingHandler.stop()
} else { } else {
existingHandler.stop() existingHandler.stop()
s.service.DeregisterMux(key) // Deregister with the priority that was used to register
if oldPriority, ok := s.handlerPriorities[key]; ok {
s.deregisterHandler([]string{key}, oldPriority)
}
} }
} }
} }
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
s.handlerPriorities = handlersByPriority
} }
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
@ -517,13 +570,13 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
s.service.DeregisterMux(nbdns.RootZone) s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
} }
for i, item := range s.currentConfig.Domains { for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found { if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true s.currentConfig.Domains[i].Disabled = true
s.service.DeregisterMux(item.Domain) s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
removeIndex[item.Domain] = i removeIndex[item.Domain] = i
} }
} }
@ -554,7 +607,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue continue
} }
s.currentConfig.Domains[i].Disabled = false s.currentConfig.Domains[i].Disabled = false
s.service.RegisterMux(domain, handler) s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@ -562,7 +615,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler) s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
@ -593,7 +646,8 @@ func (s *DefaultServer) addHostRootZone() {
} }
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
} }
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {

View File

@ -512,7 +512,7 @@ func TestDNSServerStartStop(t *testing.T) {
t.Error(err) t.Error(err)
} }
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver) dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
@ -560,6 +560,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
handlerChain: NewHandlerChain(),
handlerPriorities: make(map[string]int),
hostManager: hostManager, hostManager: hostManager,
currentConfig: HostDNSConfig{ currentConfig: HostDNSConfig{
Domains: []DomainConfig{ Domains: []DomainConfig{

View File

@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
} }
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
log.Debugf("registering dns handler for pattern: %s", pattern)
s.dnsMux.Handle(pattern, handler) s.dnsMux.Handle(pattern, handler)
} }

View File

@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
} }
} }
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("%v", u.upstreamServers)
}
func (u *upstreamResolverBase) stop() { func (u *upstreamResolverBase) stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel() u.cancel()

View File

@ -0,0 +1,120 @@
package dnsfwd
import (
"context"
"net"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
type DNSForwarder struct {
listenAddress string
ttl uint32
domains []string
dnsServer *dns.Server
mux *dns.ServeMux
}
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
domains: domains,
}
}
func (f *DNSForwarder) Listen() error {
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
mux := dns.NewServeMux()
for _, d := range f.domains {
mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
}
dnsServer := &dns.Server{
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
f.dnsServer = dnsServer
f.mux = mux
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string) {
for _, d := range f.domains {
f.mux.HandleRemove(d)
}
for _, d := range f.domains {
f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
}
f.domains = domains
}
func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil {
return nil
}
return f.dnsServer.ShutdownContext(ctx)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if len(query.Question) == 0 {
return
}
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
question := query.Question[0]
domain := question.Name
resp := query.SetReply(query)
ips, err := net.LookupIP(domain)
if err != nil {
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeServerFailure
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
return
}
for _, ip := range ips {
var respRecord dns.RR
if ip.To4() == nil {
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
rr := dns.A{
A: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}

View File

@ -0,0 +1,106 @@
package dnsfwd
import (
"context"
"fmt"
"net"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
ListenPort = 5353
dnsTTL = 60 //seconds
)
type Manager struct {
firewall firewall.Manager
fwRules []firewall.Rule
dnsForwarder *DNSForwarder
}
func NewManager(fw firewall.Manager) *Manager {
return &Manager{
firewall: fw,
}
}
func (m *Manager) Start(domains []string) error {
log.Infof("starting DNS forwarder")
if m.dnsForwarder != nil {
return nil
}
if err := m.allowDNSFirewall(); err != nil {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains)
go func() {
if err := m.dnsForwarder.Listen(); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
}()
return nil
}
func (m *Manager) UpdateDomains(domains []string) {
if m.dnsForwarder == nil {
return
}
m.dnsForwarder.UpdateDomains(domains)
}
func (m *Manager) Stop(ctx context.Context) error {
if m.dnsForwarder == nil {
return nil
}
var mErr *multierror.Error
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
if err := m.dnsForwarder.Close(ctx); err != nil {
mErr = multierror.Append(mErr, err)
}
m.dnsForwarder = nil
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []int{ListenPort},
}
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.fwRules = dnsRules
return nil
}
func (h *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
h.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
}

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"maps"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
@ -30,10 +29,12 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@ -117,7 +118,7 @@ type Engine struct {
// mgmClient is a Management Service client // mgmClient is a Management Service client
mgmClient mgm.Client mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer // peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn peerStore *peerstore.Store
beforePeerHook nbnet.AddHookFunc beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc afterPeerHook nbnet.RemoveHookFunc
@ -137,10 +138,6 @@ type Engine struct {
TURNs []*stun.URI TURNs []*stun.URI
stunTurn atomic.Value stunTurn atomic.Value
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
clientRoutesMu sync.RWMutex
clientCtx context.Context clientCtx context.Context
clientCancel context.CancelFunc clientCancel context.CancelFunc
@ -164,6 +161,7 @@ type Engine struct {
firewall manager.Manager firewall manager.Manager
routeManager routemanager.Manager routeManager routemanager.Manager
acl acl.Manager acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
dnsServer dns.Server dnsServer dns.Server
@ -234,7 +232,7 @@ func NewEngineWithProbes(
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient, mgmClient: mgmClient,
relayManager: relayManager, relayManager: relayManager,
peerConns: make(map[string]*peer.Conn), peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{},
config: config, config: config,
mobileDep: mobileDep, mobileDep: mobileDep,
@ -287,6 +285,13 @@ func (e *Engine) Stop() error {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
if e.dnsForwardMgr != nil {
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
if e.srWatcher != nil { if e.srWatcher != nil {
e.srWatcher.Close() e.srWatcher.Close()
} }
@ -300,10 +305,6 @@ func (e *Engine) Stop() error {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
} }
e.clientRoutesMu.Lock()
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
if e.cancel != nil { if e.cancel != nil {
e.cancel() e.cancel()
} }
@ -382,6 +383,8 @@ func (e *Engine) Start() error {
e.relayManager, e.relayManager,
initialRoutes, initialRoutes,
e.stateManager, e.stateManager,
dnsServer,
e.peerStore,
) )
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
@ -460,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok { if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") { if allowedIPs != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p) modified = append(modified, p)
continue continue
} }
@ -492,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
currentPeers := make([]string, 0, len(e.peerConns))
for p := range e.peerConns {
currentPeers = append(currentPeers, p)
}
newPeers := make([]string, 0, len(peersUpdate)) newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate { for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey()) newPeers = append(newPeers, p.GetWgPubKey())
} }
toRemove := util.SliceDiff(currentPeers, newPeers) toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
for _, p := range toRemove { for _, p := range toRemove {
err := e.removePeer(p) err := e.removePeer(p)
@ -516,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) removeAllPeers() error { func (e *Engine) removeAllPeers() error {
log.Debugf("removing all peer connections") log.Debugf("removing all peer connections")
for p := range e.peerConns { for _, p := range e.peerStore.PeersPubKey() {
err := e.removePeer(p) err := e.removePeer(p)
if err != nil { if err != nil {
return err return err
@ -540,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error {
} }
}() }()
conn, exists := e.peerConns[peerKey] conn, exists := e.peerStore.Remove(peerKey)
if exists { if exists {
delete(e.peerConns, peerKey)
conn.Close() conn.Close()
} }
return nil return nil
@ -786,7 +783,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
} }
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil { if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig()) err := e.updateConfig(networkMap.GetPeerConfig())
@ -806,19 +802,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.acl.ApplyFiltering(networkMap) e.acl.ApplyFiltering(networkMap)
} }
protoRoutes := networkMap.GetRoutes() routedDomains, routes := toRoutes(networkMap.GetRoutes())
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
e.clientRoutesMu.Lock() // todo: useRoutingPeerDnsResolutionEnabled from network map proto
e.clientRoutes = clientRoutes e.updateDNSForwarder(true, routedDomains)
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
@ -867,8 +858,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
} }
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
if err != nil {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
@ -881,7 +871,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil return nil
} }
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
var dnsRoutes []string
routes := make([]*route.Route, 0) routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes { for _, protoRoute := range protoRoutes {
var prefix netip.Prefix var prefix netip.Prefix
@ -892,6 +887,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
continue continue
} }
} }
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix,
@ -905,7 +902,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
} }
routes = append(routes, convertedRoute) routes = append(routes, convertedRoute)
} }
return routes return dnsRoutes, routes
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
@ -982,12 +979,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey() peerKey := peerConfig.GetWgPubKey()
peerIPs := peerConfig.GetAllowedIps() peerIPs := peerConfig.GetAllowedIps()
if _, ok := e.peerConns[peerKey]; !ok { if _, ok := e.peerStore.PeerConn(peerKey); !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil { if err != nil {
return fmt.Errorf("create peer connection: %w", err) return fmt.Errorf("create peer connection: %w", err)
} }
e.peerConns[peerKey] = conn
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil { if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook) conn.AddBeforeAddPeerHook(e.beforePeerHook)
@ -1076,8 +1077,8 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
conn := e.peerConns[msg.Key] conn, ok := e.peerStore.PeerConn(msg.Key)
if conn == nil { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)
} }
@ -1135,7 +1136,7 @@ func (e *Engine) receiveSignalEvents() {
return err return err
} }
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
case sProto.Body_MODE: case sProto.Body_MODE:
} }
@ -1239,7 +1240,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
routes := toRoutes(netMap.GetRoutes()) _, routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig()) dnsCfg := toDNSConfig(netMap.GetDNSConfig())
return routes, &dnsCfg, nil return routes, &dnsCfg, nil
} }
@ -1322,26 +1323,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
} }
} }
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
routes[id.NetID()] = v
}
return routes
}
// GetRouteManager returns the route manager // GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager { func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager return e.routeManager
@ -1426,9 +1407,8 @@ func (e *Engine) receiveProbeEvents() {
go e.probes.WgProbe.Receive(e.ctx, func() bool { go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request") log.Debug("received wg probe request")
for _, peer := range e.peerConns { for _, key := range e.peerStore.PeersPubKey() {
key := peer.GetKey() wgStats, err := e.wgInterface.GetStats(key)
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
if err != nil { if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err) log.Debugf("failed to get wg stats for peer %s: %s", key, err)
} }
@ -1505,7 +1485,7 @@ func (e *Engine) startNetworkMonitor() {
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() { for _, routes := range e.routeManager.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil { if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network) vpnRoutes = append(vpnRoutes, routes[0].Network)
} }
@ -1573,6 +1553,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nm, nil return nm, nil
} }
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
if !enabled {
if e.dnsForwardMgr == nil {
return
}
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
return
}
if len(domains) > 0 {
log.Infof("enable domain router service for domains: %v", domains)
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
if err := e.dnsForwardMgr.Start(domains); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} else {
log.Infof("update domain router service for domains: %v", domains)
e.dnsForwardMgr.UpdateDomains(domains)
}
} else if e.dnsForwardMgr != nil {
log.Infof("disable domain router service")
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
}
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks { for _, check := range checks {

View File

@ -252,7 +252,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, },
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil)
_, _, err = engine.routeManager.Init() _, _, err = engine.routeManager.Init()
require.NoError(t, err) require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
@ -392,8 +392,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
return return
} }
if len(engine.peerConns) != c.expectedLen { if len(engine.peerStore.PeersPubKey()) != c.expectedLen {
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns)) t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey()))
} }
if engine.networkSerial != c.expectedSerial { if engine.networkSerial != c.expectedSerial {
@ -401,7 +401,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
} }
for _, p := range c.expectedPeers { for _, p := range c.expectedPeers {
conn, ok := engine.peerConns[p.GetWgPubKey()] conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey())
if !ok { if !ok {
t.Errorf("expecting Engine.peerConns to contain peer %s", p) t.Errorf("expecting Engine.peerConns to contain peer %s", p)
} }
@ -626,10 +626,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{} }{}
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
input.inputSerial = updateSerial input.inputSerial = updateSerial
input.inputRoutes = newRoutes input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr return testCase.inputErr
}, },
} }
@ -802,8 +802,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil, nil, nil return nil
}, },
} }
@ -1238,7 +1238,8 @@ func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
i := 0 i := 0
for _, conn := range e.peerConns { for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.Status() == peer.StatusConnected { if conn.Status() == peer.StatusConnected {
i++ i++
} }
@ -1250,5 +1251,5 @@ func getPeers(e *Engine) int {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
return len(e.peerConns) return len(e.peerStore.PeersPubKey())
} }

View File

@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
conn.wgProxyRelay = proxy conn.wgProxyRelay = proxy
} }
// AllowedIP returns the allowed IP of the remote peer
func (conn *Conn) AllowedIP() net.IP {
return conn.allowedIP
}
func isController(config ConnConfig) bool { func isController(config ConnConfig) bool {
return config.LocalKey > config.Key return config.LocalKey > config.Key
} }

View File

@ -0,0 +1,87 @@
package peerstore
import (
"net"
"sync"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/peer"
)
// Store is a thread-safe store for peer connections.
type Store struct {
peerConns map[string]*peer.Conn
peerConnsMu sync.RWMutex
}
func NewConnStore() *Store {
return &Store{
peerConns: make(map[string]*peer.Conn),
}
}
func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
_, ok := s.peerConns[pubKey]
if ok {
return false
}
s.peerConns[pubKey] = conn
return true
}
func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
delete(s.peerConns, pubKey)
return p, true
}
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return "", false
}
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p.AllowedIP(), true
}
func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p, true
}
func (s *Store) PeersPubKey() []string {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
return maps.Keys(s.peerConns)
}

View File

@ -13,12 +13,16 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const useNewDNSRoute = true
type routerPeerStatus struct { type routerPeerStatus struct {
connected bool connected bool
relayed bool relayed bool
@ -53,7 +57,17 @@ type clientNetwork struct {
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { func newClientNetworkWatcher(
ctx context.Context,
dnsRouteInterval time.Duration,
wgInterface iface.IWGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &clientNetwork{
@ -65,7 +79,16 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface), handler: handlerFromRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouteInterval,
statusRecorder,
wgInterface,
dnsServer,
peerStore,
),
} }
return client return client
} }
@ -368,10 +391,37 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { func handlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.IWGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
) RouteHandler {
if rt.IsDynamic() { if rt.IsDynamic() {
if useNewDNSRoute {
return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
peerStore,
)
}
dns := nbdns.NewServiceViaMemory(wgInterface) dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
} }
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
} }

View File

@ -0,0 +1,334 @@
package dnsinterceptor
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
type domainMap map[domain.Domain][]netip.Prefix
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
statusRecorder *peer.Status
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
peerStore *peerstore.Store
}
func New(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
) *DnsInterceptor {
return &DnsInterceptor{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder,
dnsServer: dnsServer,
interceptedDomains: make(domainMap),
peerStore: peerStore,
}
}
func (d *DnsInterceptor) String() string {
return d.route.Domains.SafeString()
}
func (d *DnsInterceptor) AddRoute(context.Context) error {
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
return nil
}
func (d *DnsInterceptor) RemoveRoute() error {
d.mu.Lock()
defer d.mu.Unlock()
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
d.statusRecorder.DeleteResolvedDomainsStates(domain)
}
clear(d.interceptedDomains)
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
return nberrors.FormatErrorOrNil(merr)
}
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock()
defer d.mu.Unlock()
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
}
d.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (d *DnsInterceptor) RemoveAllowedIPs() error {
d.mu.Lock()
defer d.mu.Unlock()
var merr *multierror.Error
for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
d.currentPeerKey = ""
return nberrors.FormatErrorOrNil(merr)
}
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
log.Tracef("received DNS request: %v", r.Question[0].Name)
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
if peerKey == "" {
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
d.continueToNextHandler(w, r, "no current peer key")
return
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
log.Errorf("failed to get upstream IP: %v", err)
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
return
}
client := &dns.Client{
Timeout: 5 * time.Second,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
}
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if r == nil {
return fmt.Errorf("received nil DNS message")
}
if len(r.Answer) > 0 && len(r.Question) > 0 {
// DNS names from miekg/dns are already in punycode format
dom := domain.Domain(r.Question[0].Name)
var newPrefixes []netip.Prefix
for _, answer := range r.Answer {
var ip netip.Addr
switch rr := answer.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Debugf("failed to convert A record IP: %v", rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
continue
}
ip = addr
default:
continue
}
prefix := netip.PrefixFrom(ip, ip.BitLen())
newPrefixes = append(newPrefixes, prefix)
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
}
}
if err := w.WriteMsg(r); err != nil {
return fmt.Errorf("failed to write DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
oldPrefixes := d.interceptedDomains[domain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
var merr *multierror.Error
// Add new prefixes
for _, prefix := range toAdd {
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
continue
}
if d.currentPeerKey == "" {
continue
}
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
if !d.route.KeepRoute {
// Remove old prefixes
for _, prefix := range toRemove {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
}
// Update domain prefixes
if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[domain] = newPrefixes
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
}
if len(toRemove) > 0 {
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
}

View File

@ -74,11 +74,7 @@ func NewRoute(
} }
func (r *Route) String() string { func (r *Route) String() string {
s, err := r.route.Domains.String() return r.route.Domains.SafeString()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
} }
func (r *Route) AddRoute(ctx context.Context) error { func (r *Route) AddRoute(ctx context.Context) error {

View File

@ -12,12 +12,15 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
@ -33,9 +36,11 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error EnableServerRouter(firewall firewall.Manager) error
@ -60,6 +65,10 @@ type DefaultManager struct {
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
stateManager *statemanager.Manager stateManager *statemanager.Manager
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
dnsServer dns.Server
peerStore *peerstore.Store
} }
func NewManager( func NewManager(
@ -71,6 +80,8 @@ func NewManager(
relayMgr *relayClient.Manager, relayMgr *relayClient.Manager,
initialRoutes []*route.Route, initialRoutes []*route.Route,
stateManager *statemanager.Manager, stateManager *statemanager.Manager,
dnsServer dns.Server,
peerStore *peerstore.Store,
) *DefaultManager { ) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
@ -88,6 +99,8 @@ func NewManager(
pubKey: pubKey, pubKey: pubKey,
notifier: notifier, notifier: notifier,
stateManager: stateManager, stateManager: stateManager,
dnsServer: dnsServer,
peerStore: peerStore,
} }
dm.routeRefCounter = refcounter.New( dm.routeRefCounter = refcounter.New(
@ -116,7 +129,7 @@ func NewManager(
) )
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes) cr := dm.initialClientRoutes(initialRoutes)
dm.notifier.SetInitialClientRoutes(cr) dm.notifier.SetInitialClientRoutes(cr)
} }
return dm return dm
@ -207,15 +220,21 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
m.ctx = nil m.ctx = nil
m.mux.Lock()
defer m.mux.Unlock()
m.clientRoutes = nil
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not updating routes as context is closed") log.Infof("not updating routes as context is closed")
return nil, nil, m.ctx.Err() return nil
default: default:
}
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@ -228,12 +247,13 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
if m.serverRouter != nil { if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("update routes: %w", err) return err
} }
} }
return newServerRoutesMap, newClientRoutesIDMap, nil m.clientRoutes = newClientRoutesIDMap
}
return nil
} }
// SetRouteChangeListener set RouteListener for route change Notifier // SetRouteChangeListener set RouteListener for route change Notifier
@ -251,9 +271,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector return m.routeSelector
} }
// GetClientRoutes returns the client routes // GetClientRoutes returns most recent list of clientRoutes received from the Management Service
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { func (m *DefaultManager) GetClientRoutes() route.HAMap {
return m.clientNetworks m.mux.Lock()
defer m.mux.Unlock()
return maps.Clone(m.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
defer m.mux.Unlock()
routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
for id, v := range m.clientRoutes {
routes[id.NetID()] = v
}
return routes
} }
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
@ -273,7 +308,17 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue continue
} }
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) clientNetworkWatcher := newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
m.peerStore,
)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@ -302,7 +347,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }
@ -345,7 +390,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
return newServerRoutesMap, newClientRoutesIDMap return newServerRoutesMap, newClientRoutesIDMap
} }
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes) _, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0, len(crMap)) rs := make([]*route.Route, 0, len(crMap))
for _, routes := range crMap { for _, routes := range crMap {

View File

@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil)
_, _, err = routeManager.Init() _, _, err = routeManager.Init()
@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
} }
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@ -2,7 +2,6 @@ package routemanager
import ( import (
"context" "context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
@ -15,9 +14,11 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
TriggerSelectionFunc func(haMap route.HAMap) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
} }
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
if m.UpdateRoutesFunc != nil { if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes) return m.UpdateRoutesFunc(updateSerial, newRoutes)
} }
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") return nil
} }
func (m *MockManager) TriggerSelection(networks route.HAMap) { func (m *MockManager) TriggerSelection(networks route.HAMap) {
@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
return nil return nil
} }
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
func (m *MockManager) GetClientRoutes() route.HAMap {
if m.GetClientRoutesFunc != nil {
return m.GetClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {
return m.GetClientRoutesWithNetIDFunc()
}
return nil
}
// Start mock implementation of Start from Manager interface // Start mock implementation of Start from Manager interface
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
} }

View File

@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }
routesMap := engine.GetClientRoutesWithNetID()
routeManager := engine.GetRouteManager() routeManager := engine.GetRouteManager()
routesMap := routeManager.GetClientRoutesWithNetID()
if routeManager == nil { if routeManager == nil {
return nil, fmt.Errorf("could not get route manager") return nil, fmt.Errorf("could not get route manager")
} }
@ -365,12 +365,12 @@ func (c *Client) SelectRoute(id string) error {
} else { } else {
log.Debugf("select route with id: %s", id) log.Debugf("select route with id: %s", id)
routes := toNetIDs([]string{id}) routes := toNetIDs([]string{id})
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when selecting routes: %s", err) log.Debugf("error when selecting routes: %s", err)
return fmt.Errorf("select routes: %w", err) return fmt.Errorf("select routes: %w", err)
} }
} }
routeManager.TriggerSelection(engine.GetClientRoutes()) routeManager.TriggerSelection(routeManager.GetClientRoutes())
return nil return nil
} }
@ -392,12 +392,12 @@ func (c *Client) DeselectRoute(id string) error {
} else { } else {
log.Debugf("deselect route with id: %s", id) log.Debugf("deselect route with id: %s", id)
routes := toNetIDs([]string{id}) routes := toNetIDs([]string{id})
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when deselecting routes: %s", err) log.Debugf("error when deselecting routes: %s", err)
return fmt.Errorf("deselect routes: %w", err) return fmt.Errorf("deselect routes: %w", err)
} }
} }
routeManager.TriggerSelection(engine.GetClientRoutes()) routeManager.TriggerSelection(routeManager.GetClientRoutes())
return nil return nil
} }

View File

@ -34,7 +34,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }
routesMap := engine.GetClientRoutesWithNetID() routesMap := engine.GetRouteManager().GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector() routeSelector := engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute var routes []*selectRoute
@ -116,11 +116,12 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
routeSelector.SelectAllRoutes() routeSelector.SelectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetNetworkIDs()) routes := toNetIDs(req.GetNetworkIDs())
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
return nil, fmt.Errorf("select routes: %w", err) return nil, fmt.Errorf("select routes: %w", err)
} }
} }
routeManager.TriggerSelection(engine.GetClientRoutes()) routeManager.TriggerSelection(routeManager.GetClientRoutes())
return &proto.SelectNetworksResponse{}, nil return &proto.SelectNetworksResponse{}, nil
} }
@ -145,11 +146,12 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
routeSelector.DeselectAllRoutes() routeSelector.DeselectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetNetworkIDs()) routes := toNetIDs(req.GetNetworkIDs())
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err) return nil, fmt.Errorf("deselect routes: %w", err)
} }
} }
routeManager.TriggerSelection(engine.GetClientRoutes()) routeManager.TriggerSelection(routeManager.GetClientRoutes())
return &proto.SelectNetworksResponse{}, nil return &proto.SelectNetworksResponse{}, nil
} }

View File

@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) {
return validHost, nil return validHost, nil
} }
// NormalizeZone returns a normalized domain name without the wildcard prefix
func NormalizeZone(domain string) string {
d, _ := strings.CutPrefix(domain, "*.")
return d
}

1
go.mod
View File

@ -207,6 +207,7 @@ require (
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect github.com/vishvananda/netns v0.0.4 // indirect

1
go.sum
View File

@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=

View File

@ -360,7 +360,7 @@ func validateDomains(domains []string) (domain.List, error) {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
} }
domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
var domainList domain.List var domainList domain.List

View File

@ -330,6 +330,14 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{
name: "POST Wildcard Domain",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["*.example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)),
expectedStatus: http.StatusOK,
expectedBody: false,
},
{ {
name: "POST UnprocessableEntity when both network and domains are provided", name: "POST UnprocessableEntity when both network and domains are provided",
requestType: http.MethodPost, requestType: http.MethodPost,
@ -609,6 +617,30 @@ func TestValidateDomains(t *testing.T) {
expected: domain.List{"google.com"}, expected: domain.List{"google.com"},
wantErr: true, wantErr: true,
}, },
{
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
expected: domain.List{"*.example.com"},
wantErr: false,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid wildcard domain",
domains: []string{"a.*.example.com"},
expected: nil,
wantErr: true,
},
} }
for _, tt := range tests { for _, tt := range tests {