From c91d7808bf4383e1e76196ac70f24f7cc8e9012b Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 13 Dec 2024 14:17:10 +0100 Subject: [PATCH] Add dns interceptor based domain route functionality (#3032) --- .github/workflows/golangci-lint.yml | 4 +- client/internal/dns/handler_chain.go | 192 +++++++ client/internal/dns/handler_chain_test.go | 490 ++++++++++++++++++ client/internal/dns/mock_server.go | 22 +- client/internal/dns/server.go | 94 +++- client/internal/dns/server_test.go | 6 +- client/internal/dns/service_listener.go | 1 + client/internal/dns/upstream.go | 5 + client/internal/dnsfwd/forwarder.go | 120 +++++ client/internal/dnsfwd/manager.go | 106 ++++ client/internal/engine.go | 152 +++--- client/internal/engine_test.go | 21 +- client/internal/peer/conn.go | 5 + client/internal/peerstore/store.go | 87 ++++ client/internal/routemanager/client.go | 58 ++- .../routemanager/dnsinterceptor/handler.go | 334 ++++++++++++ client/internal/routemanager/dynamic/route.go | 6 +- client/internal/routemanager/manager.go | 99 +++- client/internal/routemanager/manager_test.go | 6 +- client/internal/routemanager/mock.go | 31 +- client/ios/NetBirdSDK/client.go | 10 +- client/server/network.go | 12 +- dns/dns.go | 6 + go.mod | 1 + go.sum | 1 + .../http/handlers/routes/routes_handler.go | 2 +- .../handlers/routes/routes_handler_test.go | 32 ++ 27 files changed, 1740 insertions(+), 163 deletions(-) create mode 100644 client/internal/dns/handler_chain.go create mode 100644 client/internal/dns/handler_chain_test.go create mode 100644 client/internal/dnsfwd/forwarder.go create mode 100644 client/internal/dnsfwd/manager.go create mode 100644 client/internal/peerstore/store.go create mode 100644 client/internal/routemanager/dnsinterceptor/handler.go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index dacb1922b..89defce32 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -46,7 +46,7 @@ jobs: if: matrix.os == 'ubuntu-latest' run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - args: --timeout=12m + args: --timeout=12m --out-format colored-line-number diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go new file mode 100644 index 000000000..4a525844b --- /dev/null +++ b/client/internal/dns/handler_chain.go @@ -0,0 +1,192 @@ +package dns + +import ( + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" +) + +const ( + PriorityDNSRoute = 100 + PriorityMatchDomain = 50 + PriorityDefault = 0 +) + +type HandlerEntry struct { + Handler dns.Handler + Priority int + Pattern string + OrigPattern string + IsWildcard bool + StopHandler handlerWithStop +} + +// HandlerChain represents a prioritized chain of DNS handlers +type HandlerChain struct { + mu sync.RWMutex + handlers []HandlerEntry +} + +// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain +type ResponseWriterChain struct { + dns.ResponseWriter + shouldContinue bool +} + +func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { + // Check if this is a continue signal (NXDOMAIN with Zero bit set) + if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero { + w.shouldContinue = true + return nil + } + return w.ResponseWriter.WriteMsg(m) +} + +func NewHandlerChain() *HandlerChain { + return &HandlerChain{ + handlers: make([]HandlerEntry, 0), + } +} + +// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority +func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { + c.mu.Lock() + defer c.mu.Unlock() + + origPattern := pattern + isWildcard := strings.HasPrefix(pattern, "*.") + if isWildcard { + pattern = pattern[2:] + } + pattern = dns.Fqdn(pattern) + origPattern = dns.Fqdn(origPattern) + + // First remove any existing handler with same original pattern and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { + if c.handlers[i].StopHandler != nil { + c.handlers[i].StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + break + } + } + + log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d", + pattern, origPattern, isWildcard, priority) + + entry := HandlerEntry{ + Handler: handler, + Priority: priority, + Pattern: pattern, + OrigPattern: origPattern, + IsWildcard: isWildcard, + StopHandler: stopHandler, + } + + // Insert handler in priority order + pos := 0 + for i, h := range c.handlers { + if h.Priority < priority { + pos = i + break + } + pos = i + 1 + } + + c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) +} + +// RemoveHandler removes a handler for the given pattern and priority +func (c *HandlerChain) RemoveHandler(pattern string, priority int) { + c.mu.Lock() + defer c.mu.Unlock() + + pattern = dns.Fqdn(pattern) + + // Find and remove handlers matching both original pattern and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + entry := c.handlers[i] + if entry.OrigPattern == pattern && entry.Priority == priority { + if entry.StopHandler != nil { + entry.StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + return + } + } +} + +// HasHandlers returns true if there are any handlers remaining for the given pattern +func (c *HandlerChain) HasHandlers(pattern string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + pattern = dns.Fqdn(pattern) + for _, entry := range c.handlers { + if entry.Pattern == pattern { + return true + } + } + return false +} + +func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + + qname := r.Question[0].Name + log.Debugf("handling DNS request for %s", qname) + + c.mu.RLock() + defer c.mu.RUnlock() + + log.Debugf("current handlers (%d):", len(c.handlers)) + for _, h := range c.handlers { + log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + } + + // Try handlers in priority order + for _, entry := range c.handlers { + var matched bool + switch { + case entry.Pattern == ".": + matched = true + case entry.IsWildcard: + parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") + matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) + default: + matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + } + + if !matched { + log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false", + entry.OrigPattern, qname, entry.IsWildcard) + continue + } + + log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v", + entry.OrigPattern, qname, entry.IsWildcard) + chainWriter := &ResponseWriterChain{ResponseWriter: w} + entry.Handler.ServeDNS(chainWriter, r) + + // If handler wants to continue, try next handler + if chainWriter.shouldContinue { + log.Debugf("handler requested continue to next handler") + continue + } + return + } + + // No handler matched or all handlers passed + log.Debugf("no handler found for %s", qname) + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go new file mode 100644 index 000000000..01ed5f4e7 --- /dev/null +++ b/client/internal/dns/handler_chain_test.go @@ -0,0 +1,490 @@ +package dns_test + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + nbdns "github.com/netbirdio/netbird/client/internal/dns" +) + +// MockHandler implements dns.Handler interface for testing +type MockHandler struct { + mock.Mock +} + +func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + m.Called(w, r) +} + +// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order +func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create mock handlers for different priorities + defaultHandler := &MockHandler{} + matchDomainHandler := &MockHandler{} + dnsRouteHandler := &MockHandler{} + + // Setup handlers with different priorities + chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Create test writer + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations - only highest priority handler should be called + dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() + matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe() + defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() + + // Execute + chain.ServeDNS(w, r) + + // Verify all expectations were met + dnsRouteHandler.AssertExpectations(t) + matchDomainHandler.AssertExpectations(t) + defaultHandler.AssertExpectations(t) +} + +// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios +func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { + tests := []struct { + name string + handlerDomain string + queryDomain string + isWildcard bool + shouldMatch bool + }{ + { + name: "exact match", + handlerDomain: "example.com.", + queryDomain: "example.com.", + isWildcard: false, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + shouldMatch: true, + }, + { + name: "wildcard match", + handlerDomain: "*.example.com.", + queryDomain: "sub.example.com.", + isWildcard: true, + shouldMatch: true, + }, + { + name: "wildcard no match on apex", + handlerDomain: "*.example.com.", + queryDomain: "example.com.", + isWildcard: true, + shouldMatch: false, + }, + { + name: "root zone match", + handlerDomain: ".", + queryDomain: "anything.com.", + isWildcard: false, + shouldMatch: true, + }, + { + name: "no match different domain", + handlerDomain: "example.com.", + queryDomain: "example.org.", + isWildcard: false, + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + mockHandler := &MockHandler{} + + pattern := tt.handlerDomain + if tt.isWildcard { + pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard + } + + chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil) + + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + if tt.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, r).Once() + } + + chain.ServeDNS(w, r) + mockHandler.AssertExpectations(t) + }) + } +} + +// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns +func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { + tests := []struct { + name string + handlers []struct { + pattern string + priority int + } + queryDomain string + expectedCalls int + expectedHandler int // index of the handler that should be called + }{ + { + name: "wildcard and exact same priority - exact should win", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // exact match handler should be called + }, + { + name: "higher priority wildcard over lower priority exact", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority wildcard handler should be called + }, + { + name: "multiple wildcards different priorities", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority handler should be called + }, + { + name: "subdomain with mix of patterns", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "sub.test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority matching handler should be called + }, + { + name: "root zone with specific domain", + handlers: []struct { + pattern string + priority int + }{ + {pattern: ".", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority specific domain should win over root + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handlers []*MockHandler + + // Setup handlers and expectations + for i := range tt.handlers { + handler := &MockHandler{} + handlers = append(handlers, handler) + + // Set expectation based on whether this handler should be called + if i == tt.expectedHandler { + handler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } else { + handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() + } + + chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) + } + + // Create and execute request + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality +func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create handlers + handler1 := &MockHandler{} + handler2 := &MockHandler{} + handler3 := &MockHandler{} + + // Add handlers in priority order + chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) + chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Setup mock responses to simulate chain continuation + handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // First handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true // Signal to continue + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Second handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Last handler responds normally + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + // Execute + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify all handlers were called in order + handler1.AssertExpectations(t) + handler2.AssertExpectations(t) + handler3.AssertExpectations(t) +} + +// mockResponseWriter implements dns.ResponseWriter for testing +type mockResponseWriter struct { + mock.Mock +} + +func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } +func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } +func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) Close() error { return nil } +func (m *mockResponseWriter) TsigStatus() error { return nil } +func (m *mockResponseWriter) TsigTimersOnly(bool) {} +func (m *mockResponseWriter) Hijack() {} + +func TestHandlerChain_PriorityDeregistration(t *testing.T) { + tests := []struct { + name string + ops []struct { + action string // "add" or "remove" + pattern string + priority int + } + query string + expectedCalls map[int]bool // map[priority]shouldBeCalled + }{ + { + name: "remove high priority keeps lower priority handler", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: true, + }, + }, + { + name: "remove lower priority keeps high priority handler", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: true, + nbdns.PriorityMatchDomain: false, + }, + }, + { + name: "remove all handlers in order", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityDefault}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: false, + nbdns.PriorityDefault: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlers := make(map[int]*MockHandler) + + // Execute operations + for _, op := range tt.ops { + if op.action == "add" { + handler := &MockHandler{} + handlers[op.priority] = handler + chain.AddHandler(op.pattern, handler, op.priority, nil) + } else { + chain.RemoveHandler(op.pattern, op.priority) + } + } + + // Create test request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations + for priority, handler := range handlers { + if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall { + handler.On("ServeDNS", mock.Anything, r).Once() + } else { + handler.On("ServeDNS", mock.Anything, r).Maybe() + } + } + + // Execute request + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + + // Verify handler exists check + for priority, shouldExist := range tt.expectedCalls { + if shouldExist { + assert.True(t, chain.HasHandlers(tt.ops[0].pattern), + "Handler chain should have handlers for pattern after removing priority %d", priority) + } + } + }) + } +} + +func TestHandlerChain_MultiPriorityHandling(t *testing.T) { + chain := nbdns.NewHandlerChain() + + testDomain := "example.com." + testQuery := "test.example.com." + + // Create handlers for three priority levels + routeHandler := &MockHandler{} + matchHandler := &MockHandler{} + defaultHandler := &MockHandler{} + + // Create test request that will be reused + r := new(dns.Msg) + r.SetQuestion(testQuery, dns.TypeA) + + // Add handlers in mixed order + chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) + + // Test 1: Initial state with all three handlers + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Highest priority handler (routeHandler) should be called + routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + routeHandler.AssertExpectations(t) + + // Test 2: Remove highest priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now middle priority handler (matchHandler) should be called + matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + matchHandler.AssertExpectations(t) + + // Test 3: Remove middle priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now lowest priority handler (defaultHandler) should be called + defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + defaultHandler.AssertExpectations(t) + + // Test 4: Remove last handler + chain.RemoveHandler(testDomain, nbdns.PriorityDefault) + assert.False(t, chain.HasHandlers(testDomain)) +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 0739f0542..7e36ea5df 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,14 +3,30 @@ package dns import ( "fmt" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func([]string, dns.Handler, int) + DeregisterHandlerFunc func([]string, int) +} + +func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + if m.RegisterHandlerFunc != nil { + m.RegisterHandlerFunc(domains, handler, priority) + } +} + +func (m *MockServer) DeregisterHandler(domains []string, priority int) { + if m.DeregisterHandlerFunc != nil { + m.DeregisterHandlerFunc(domains, priority) + } } // Initialize mock implementation of Initialize from Server interface diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f0277319c..55043c5b2 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -30,6 +30,8 @@ type IosDnsManager interface { // Server is a dns server interface type Server interface { + RegisterHandler(domains []string, handler dns.Handler, priority int) + DeregisterHandler(domains []string, priority int) Initialize() error Stop() DnsIP() string @@ -48,12 +50,14 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap + handlerPriorities map[string]int localResolver *localResolver wgInterface WGIface hostManager hostManager updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + handlerChain *HandlerChain // permanent related properties permanent bool @@ -74,8 +78,9 @@ type handlerWithStop interface { } type muxUpdate struct { - domain string - handler handlerWithStop + domain string + handler handlerWithStop + priority int } // NewDefaultServer returns a new dns server @@ -135,10 +140,12 @@ func NewDefaultServerIos( func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - service: dnsService, - dnsMuxMap: make(registeredHandlerMap), + ctx: ctx, + ctxCancel: stop, + service: dnsService, + handlerChain: NewHandlerChain(), + dnsMuxMap: make(registeredHandlerMap), + handlerPriorities: make(map[string]int), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -151,6 +158,41 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi return defaultServer } +func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.registerHandler(domains, handler, priority) +} + +func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { + log.Debugf("registering handler %s with priority %d", handler, priority) + + for _, domain := range domains { + s.handlerChain.AddHandler(domain, handler, priority, nil) + s.handlerPriorities[domain] = priority + s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) + } +} + +func (s *DefaultServer) DeregisterHandler(domains []string, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.deregisterHandler(domains, priority) +} + +func (s *DefaultServer) deregisterHandler(domains []string, priority int) { + for _, domain := range domains { + s.handlerChain.RemoveHandler(domain, priority) + + // Only deregister from service if no handlers remain + if !s.handlerChain.HasHandlers(domain) { + s.service.DeregisterMux(nbdns.NormalizeZone(domain)) + } + } +} + // Initialize instantiate host manager and the dns service func (s *DefaultServer) Initialize() (err error) { s.mux.Lock() @@ -343,14 +385,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) localRecords := make(map[string]nbdns.SimpleRecord, 0) for _, customZone := range customZones { - if len(customZone.Records) == 0 { return nil, nil, fmt.Errorf("received an empty list of records") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: customZone.Domain, - handler: s.localResolver, + domain: customZone.Domain, + handler: s.localResolver, + priority: PriorityMatchDomain, }) for _, record := range customZone.Records { @@ -412,8 +454,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam if nsGroup.Primary { muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, + domain: nbdns.RootZone, + handler: handler, + priority: PriorityDefault, }) continue } @@ -429,8 +472,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam return nil, fmt.Errorf("received a nameserver group with an empty domain element") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, + domain: domain, + handler: handler, + priority: PriorityMatchDomain, }) } } @@ -440,12 +484,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { muxUpdateMap := make(registeredHandlerMap) + handlersByPriority := make(map[string]int) var isContainRootUpdate bool + // First register new handlers for _, update := range muxUpdates { - s.service.RegisterMux(update.domain, update.handler) + s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.domain] = update.handler + handlersByPriority[update.domain] = update.priority + if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { existingHandler.stop() } @@ -455,6 +503,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { } } + // Then deregister old handlers not in the update for key, existingHandler := range s.dnsMuxMap { _, found := muxUpdateMap[key] if !found { @@ -463,12 +512,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { existingHandler.stop() } else { existingHandler.stop() - s.service.DeregisterMux(key) + // Deregister with the priority that was used to register + if oldPriority, ok := s.handlerPriorities[key]; ok { + s.deregisterHandler([]string{key}, oldPriority) + } } } } s.dnsMuxMap = muxUpdateMap + s.handlerPriorities = handlersByPriority } func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { @@ -517,13 +570,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.service.DeregisterMux(nbdns.RootZone) + s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.service.DeregisterMux(item.Domain) + s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) removeIndex[item.Domain] = i } } @@ -554,7 +607,7 @@ func (s *DefaultServer) upstreamCallbacks( continue } s.currentConfig.Domains[i].Disabled = false - s.service.RegisterMux(domain, handler) + s.registerHandler([]string{domain}, handler, PriorityMatchDomain) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -562,7 +615,7 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.service.RegisterMux(nbdns.RootZone, handler) + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") @@ -593,7 +646,8 @@ func (s *DefaultServer) addHostRootZone() { } handler.deactivate = func(error) {} handler.reactivate = func() {} - s.service.RegisterMux(nbdns.RootZone, handler) + + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index eab9f4ecb..aca7653a3 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -512,7 +512,7 @@ func TestDNSServerStartStop(t *testing.T) { t.Error(err) } - dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver) + dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1) resolver := &net.Resolver{ PreferGo: true, @@ -560,7 +560,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { localResolver: &localResolver{ registeredMap: make(registrationMap), }, - hostManager: hostManager, + handlerChain: NewHandlerChain(), + handlerPriorities: make(map[string]int), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index e0f9da26f..72dc4bc6e 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() { } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { + log.Debugf("registering dns handler for pattern: %s", pattern) s.dnsMux.Handle(pattern, handler) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index b3baf2fa8..94497c61f 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) * } } +// String returns a string representation of the upstream resolver +func (u *upstreamResolverBase) String() string { + return fmt.Sprintf("%v", u.upstreamServers) +} + func (u *upstreamResolverBase) stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go new file mode 100644 index 000000000..dd9636158 --- /dev/null +++ b/client/internal/dnsfwd/forwarder.go @@ -0,0 +1,120 @@ +package dnsfwd + +import ( + "context" + "net" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" +) + +type DNSForwarder struct { + listenAddress string + ttl uint32 + domains []string + + dnsServer *dns.Server + mux *dns.ServeMux +} + +func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder { + log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains) + return &DNSForwarder{ + listenAddress: listenAddress, + ttl: ttl, + domains: domains, + } +} +func (f *DNSForwarder) Listen() error { + log.Infof("listen DNS forwarder on: %s", f.listenAddress) + mux := dns.NewServeMux() + + for _, d := range f.domains { + mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery) + } + + dnsServer := &dns.Server{ + Addr: f.listenAddress, + Net: "udp", + Handler: mux, + } + f.dnsServer = dnsServer + f.mux = mux + return dnsServer.ListenAndServe() +} + +func (f *DNSForwarder) UpdateDomains(domains []string) { + for _, d := range f.domains { + f.mux.HandleRemove(d) + } + + for _, d := range f.domains { + f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery) + } + f.domains = domains +} + +func (f *DNSForwarder) Close(ctx context.Context) error { + if f.dnsServer == nil { + return nil + } + return f.dnsServer.ShutdownContext(ctx) +} + +func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { + if len(query.Question) == 0 { + return + } + log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name) + + question := query.Question[0] + domain := question.Name + + resp := query.SetReply(query) + + ips, err := net.LookupIP(domain) + if err != nil { + log.Warnf("failed to resolve query for domain %s: %v", domain, err) + resp.Rcode = dns.RcodeServerFailure + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write failure DNS response: %v", err) + } + return + } + + for _, ip := range ips { + var respRecord dns.RR + if ip.To4() == nil { + log.Tracef("resolved domain %s to IPv6 %s", domain, ip) + rr := dns.AAAA{ + AAAA: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } else { + log.Tracef("resolved domain %s to IPv4 %s", domain, ip) + rr := dns.A{ + A: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } + resp.Answer = append(resp.Answer, respRecord) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go new file mode 100644 index 000000000..bc05e0cec --- /dev/null +++ b/client/internal/dnsfwd/manager.go @@ -0,0 +1,106 @@ +package dnsfwd + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +const ( + // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also + ListenPort = 5353 + dnsTTL = 60 //seconds +) + +type Manager struct { + firewall firewall.Manager + + fwRules []firewall.Rule + dnsForwarder *DNSForwarder +} + +func NewManager(fw firewall.Manager) *Manager { + return &Manager{ + firewall: fw, + } +} + +func (m *Manager) Start(domains []string) error { + log.Infof("starting DNS forwarder") + if m.dnsForwarder != nil { + return nil + } + + if err := m.allowDNSFirewall(); err != nil { + return err + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains) + go func() { + if err := m.dnsForwarder.Listen(); err != nil { + // todo handle close error if it is exists + log.Errorf("failed to start DNS forwarder, err: %v", err) + } + }() + + return nil +} + +func (m *Manager) UpdateDomains(domains []string) { + if m.dnsForwarder == nil { + return + } + + m.dnsForwarder.UpdateDomains(domains) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m.dnsForwarder == nil { + return nil + } + + var mErr *multierror.Error + if err := m.dropDNSFirewall(); err != nil { + mErr = multierror.Append(mErr, err) + } + + if err := m.dnsForwarder.Close(ctx); err != nil { + mErr = multierror.Append(mErr, err) + } + + m.dnsForwarder = nil + return nberrors.FormatErrorOrNil(mErr) +} + +func (h *Manager) allowDNSFirewall() error { + dport := &firewall.Port{ + IsRange: false, + Values: []int{ListenPort}, + } + dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + h.fwRules = dnsRules + + return nil +} + +func (h *Manager) dropDNSFirewall() error { + var mErr *multierror.Error + for _, rule := range h.fwRules { + if err := h.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) + } + } + + h.fwRules = nil + return nberrors.FormatErrorOrNil(mErr) +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 34219def1..caf39a34f 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "maps" "math/rand" "net" "net/netip" @@ -30,10 +29,12 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -117,7 +118,7 @@ type Engine struct { // mgmClient is a Management Service client mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer - peerConns map[string]*peer.Conn + peerStore *peerstore.Store beforePeerHook nbnet.AddHookFunc afterPeerHook nbnet.RemoveHookFunc @@ -137,10 +138,6 @@ type Engine struct { TURNs []*stun.URI stunTurn atomic.Value - // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes route.HAMap - clientRoutesMu sync.RWMutex - clientCtx context.Context clientCancel context.CancelFunc @@ -161,9 +158,10 @@ type Engine struct { statusRecorder *peer.Status - firewall manager.Manager - routeManager routemanager.Manager - acl acl.Manager + firewall manager.Manager + routeManager routemanager.Manager + acl acl.Manager + dnsForwardMgr *dnsfwd.Manager dnsServer dns.Server @@ -234,7 +232,7 @@ func NewEngineWithProbes( signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), mgmClient: mgmClient, relayManager: relayManager, - peerConns: make(map[string]*peer.Conn), + peerStore: peerstore.NewConnStore(), syncMsgMux: &sync.Mutex{}, config: config, mobileDep: mobileDep, @@ -287,6 +285,13 @@ func (e *Engine) Stop() error { e.routeManager.Stop(e.stateManager) } + if e.dnsForwardMgr != nil { + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } + if e.srWatcher != nil { e.srWatcher.Close() } @@ -300,10 +305,6 @@ func (e *Engine) Stop() error { return fmt.Errorf("failed to remove all peers: %s", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = nil - e.clientRoutesMu.Unlock() - if e.cancel != nil { e.cancel() } @@ -382,6 +383,8 @@ func (e *Engine) Start() error { e.relayManager, initialRoutes, e.stateManager, + dnsServer, + e.peerStore, ) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { @@ -460,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { var modified []*mgmProto.RemotePeerConfig for _, p := range peersUpdate { peerPubKey := p.GetWgPubKey() - if peerConn, ok := e.peerConns[peerPubKey]; ok { - if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") { + if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { + if allowedIPs != strings.Join(p.AllowedIps, ",") { modified = append(modified, p) continue } @@ -492,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { - currentPeers := make([]string, 0, len(e.peerConns)) - for p := range e.peerConns { - currentPeers = append(currentPeers, p) - } - newPeers := make([]string, 0, len(peersUpdate)) for _, p := range peersUpdate { newPeers = append(newPeers, p.GetWgPubKey()) } - toRemove := util.SliceDiff(currentPeers, newPeers) + toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers) for _, p := range toRemove { err := e.removePeer(p) @@ -516,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) removeAllPeers() error { log.Debugf("removing all peer connections") - for p := range e.peerConns { + for _, p := range e.peerStore.PeersPubKey() { err := e.removePeer(p) if err != nil { return err @@ -540,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error { } }() - conn, exists := e.peerConns[peerKey] + conn, exists := e.peerStore.Remove(peerKey) if exists { - delete(e.peerConns, peerKey) conn.Close() } return nil @@ -786,7 +783,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error { } func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { - // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't if networkMap.GetPeerConfig() != nil { err := e.updateConfig(networkMap.GetPeerConfig()) @@ -806,19 +802,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - protoRoutes := networkMap.GetRoutes() - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} - } + routedDomains, routes := toRoutes(networkMap.GetRoutes()) - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { + if err := e.routeManager.UpdateRoutes(serial, routes); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = clientRoutes - e.clientRoutesMu.Unlock() + // todo: useRoutingPeerDnsResolutionEnabled from network map proto + e.updateDNSForwarder(true, routedDomains) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) @@ -867,8 +858,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) - if err != nil { + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { log.Errorf("failed to update dns server, err: %v", err) } @@ -881,7 +871,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } -func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { +func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + var dnsRoutes []string routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { var prefix netip.Prefix @@ -892,6 +887,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { continue } } + dnsRoutes = append(dnsRoutes, protoRoute.Domains...) + convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), Network: prefix, @@ -905,7 +902,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { } routes = append(routes, convertedRoute) } - return routes + return dnsRoutes, routes } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { @@ -982,12 +979,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { peerKey := peerConfig.GetWgPubKey() peerIPs := peerConfig.GetAllowedIps() - if _, ok := e.peerConns[peerKey]; !ok { + if _, ok := e.peerStore.PeerConn(peerKey); !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { return fmt.Errorf("create peer connection: %w", err) } - e.peerConns[peerKey] = conn + + if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { + conn.Close() + return fmt.Errorf("peer already exists: %s", peerKey) + } if e.beforePeerHook != nil && e.afterPeerHook != nil { conn.AddBeforeAddPeerHook(e.beforePeerHook) @@ -1076,8 +1077,8 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - conn := e.peerConns[msg.Key] - if conn == nil { + conn, ok := e.peerStore.PeerConn(msg.Key) + if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) } @@ -1135,7 +1136,7 @@ func (e *Engine) receiveSignalEvents() { return err } - go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) + go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes()) case sProto.Body_MODE: } @@ -1239,7 +1240,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { if err != nil { return nil, nil, err } - routes := toRoutes(netMap.GetRoutes()) + _, routes := toRoutes(netMap.GetRoutes()) dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, &dnsCfg, nil } @@ -1322,26 +1323,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { } } -// GetClientRoutes returns the current routes from the route map -func (e *Engine) GetClientRoutes() route.HAMap { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - return maps.Clone(e.clientRoutes) -} - -// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only -func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) - for id, v := range e.clientRoutes { - routes[id.NetID()] = v - } - return routes -} - // GetRouteManager returns the route manager func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager @@ -1426,9 +1407,8 @@ func (e *Engine) receiveProbeEvents() { go e.probes.WgProbe.Receive(e.ctx, func() bool { log.Debug("received wg probe request") - for _, peer := range e.peerConns { - key := peer.GetKey() - wgStats, err := peer.WgConfig().WgInterface.GetStats(key) + for _, key := range e.peerStore.PeersPubKey() { + wgStats, err := e.wgInterface.GetStats(key) if err != nil { log.Debugf("failed to get wg stats for peer %s: %s", key, err) } @@ -1505,7 +1485,7 @@ func (e *Engine) startNetworkMonitor() { func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { var vpnRoutes []netip.Prefix - for _, routes := range e.GetClientRoutes() { + for _, routes := range e.routeManager.GetClientRoutes() { if len(routes) > 0 && routes[0] != nil { vpnRoutes = append(vpnRoutes, routes[0].Network) } @@ -1573,6 +1553,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nm, nil } +// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag +func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { + if !enabled { + if e.dnsForwardMgr == nil { + return + } + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + return + } + + if len(domains) > 0 { + log.Infof("enable domain router service for domains: %v", domains) + if e.dnsForwardMgr == nil { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall) + + if err := e.dnsForwardMgr.Start(domains); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } + } else { + log.Infof("update domain router service for domains: %v", domains) + e.dnsForwardMgr.UpdateDomains(domains) + } + } else if e.dnsForwardMgr != nil { + log.Infof("disable domain router service") + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { for _, check := range checks { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f400efa55..9305c0b5a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -252,7 +252,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil) _, _, err = engine.routeManager.Init() require.NoError(t, err) engine.dnsServer = &dns.MockServer{ @@ -392,8 +392,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { return } - if len(engine.peerConns) != c.expectedLen { - t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns)) + if len(engine.peerStore.PeersPubKey()) != c.expectedLen { + t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey())) } if engine.networkSerial != c.expectedSerial { @@ -401,7 +401,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { } for _, p := range c.expectedPeers { - conn, ok := engine.peerConns[p.GetWgPubKey()] + conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey()) if !ok { t.Errorf("expecting Engine.peerConns to contain peer %s", p) } @@ -626,10 +626,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { }{} mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { input.inputSerial = updateSerial input.inputRoutes = newRoutes - return nil, nil, testCase.inputErr + return testCase.inputErr }, } @@ -802,8 +802,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { - return nil, nil, nil + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + return nil }, } @@ -1238,7 +1238,8 @@ func getConnectedPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() i := 0 - for _, conn := range e.peerConns { + for _, id := range e.peerStore.PeersPubKey() { + conn, _ := e.peerStore.PeerConn(id) if conn.Status() == peer.StatusConnected { i++ } @@ -1250,5 +1251,5 @@ func getPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - return len(e.peerConns) + return len(e.peerStore.PeersPubKey()) } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 5c2e2cb60..b8cb2582f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { conn.wgProxyRelay = proxy } +// AllowedIP returns the allowed IP of the remote peer +func (conn *Conn) AllowedIP() net.IP { + return conn.allowedIP +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go new file mode 100644 index 000000000..6b3385ff5 --- /dev/null +++ b/client/internal/peerstore/store.go @@ -0,0 +1,87 @@ +package peerstore + +import ( + "net" + "sync" + + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +// Store is a thread-safe store for peer connections. +type Store struct { + peerConns map[string]*peer.Conn + peerConnsMu sync.RWMutex +} + +func NewConnStore() *Store { + return &Store{ + peerConns: make(map[string]*peer.Conn), + } +} + +func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + _, ok := s.peerConns[pubKey] + if ok { + return false + } + + s.peerConns[pubKey] = conn + return true +} + +func (s *Store) Remove(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + delete(s.peerConns, pubKey) + return p, true +} + +func (s *Store) AllowedIPs(pubKey string) (string, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return "", false + } + return p.WgConfig().AllowedIps, true +} + +func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p.AllowedIP(), true +} + +func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p, true +} + +func (s *Store) PeersPubKey() []string { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + return maps.Keys(s.peerConns) +} diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 13e45b3a3..b7fc5b15d 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -13,12 +13,16 @@ import ( "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/route" ) +const useNewDNSRoute = true + type routerPeerStatus struct { connected bool relayed bool @@ -53,7 +57,17 @@ type clientNetwork struct { updateSerial uint64 } -func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { +func newClientNetworkWatcher( + ctx context.Context, + dnsRouteInterval time.Duration, + wgInterface iface.IWGIface, + statusRecorder *peer.Status, + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsServer nbdns.Server, + peerStore *peerstore.Store, +) *clientNetwork { ctx, cancel := context.WithCancel(ctx) client := &clientNetwork{ @@ -65,7 +79,16 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface), + handler: handlerFromRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouteInterval, + statusRecorder, + wgInterface, + dnsServer, + peerStore, + ), } return client } @@ -368,10 +391,37 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { +func handlerFromRoute( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsRouterInteval time.Duration, + statusRecorder *peer.Status, + wgInterface iface.IWGIface, + dnsServer nbdns.Server, + peerStore *peerstore.Store, +) RouteHandler { if rt.IsDynamic() { + if useNewDNSRoute { + return dnsinterceptor.New( + rt, + routeRefCounter, + allowedIPsRefCounter, + statusRecorder, + dnsServer, + peerStore, + ) + } dns := nbdns.NewServiceViaMemory(wgInterface) - return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) + return dynamic.NewRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouterInteval, + statusRecorder, + wgInterface, + fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), + ) } return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go new file mode 100644 index 000000000..702290015 --- /dev/null +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -0,0 +1,334 @@ +package dnsinterceptor + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" +) + +type domainMap map[domain.Domain][]netip.Prefix + +type DnsInterceptor struct { + mu sync.RWMutex + route *route.Route + routeRefCounter *refcounter.RouteRefCounter + allowedIPsRefcounter *refcounter.AllowedIPsRefCounter + statusRecorder *peer.Status + dnsServer nbdns.Server + currentPeerKey string + interceptedDomains domainMap + peerStore *peerstore.Store +} + +func New( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + statusRecorder *peer.Status, + dnsServer nbdns.Server, + peerStore *peerstore.Store, +) *DnsInterceptor { + return &DnsInterceptor{ + route: rt, + routeRefCounter: routeRefCounter, + allowedIPsRefcounter: allowedIPsRefCounter, + statusRecorder: statusRecorder, + dnsServer: dnsServer, + interceptedDomains: make(domainMap), + peerStore: peerStore, + } +} + +func (d *DnsInterceptor) String() string { + return d.route.Domains.SafeString() +} + +func (d *DnsInterceptor) AddRoute(context.Context) error { + d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute) + return nil +} + +func (d *DnsInterceptor) RemoveRoute() error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) + + d.statusRecorder.DeleteResolvedDomainsStates(domain) + } + + clear(d.interceptedDomains) + + d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) + + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } + } + } + + d.currentPeerKey = peerKey + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) RemoveAllowedIPs() error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for _, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + + d.currentPeerKey = "" + return nberrors.FormatErrorOrNil(merr) +} + +// ServeDNS implements the dns.Handler interface +func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + log.Tracef("received DNS request: %v", r.Question[0].Name) + + d.mu.RLock() + peerKey := d.currentPeerKey + d.mu.RUnlock() + + if peerKey == "" { + log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name) + d.continueToNextHandler(w, r, "no current peer key") + return + } + + upstreamIP, err := d.getUpstreamIP(peerKey) + if err != nil { + log.Errorf("failed to get upstream IP: %v", err) + d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err)) + return + } + + client := &dns.Client{ + Timeout: 5 * time.Second, + Net: "udp", + } + upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) + reply, _, err := client.ExchangeContext(context.Background(), r, upstream) + + var answer []dns.RR + if reply != nil { + answer = reply.Answer + } + log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer) + + if err != nil { + log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } + return + } + + reply.Id = r.Id + if err := d.writeMsg(w, reply); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } +} + +// continueToNextHandler signals the handler chain to try the next handler +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { + log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + // Set Zero bit to signal handler chain to continue + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed writing DNS continue response: %v", err) + } +} + +func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { + peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) + if !exists { + return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) + } + return peerAllowedIP, nil +} + +func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { + if r == nil { + return fmt.Errorf("received nil DNS message") + } + + if len(r.Answer) > 0 && len(r.Question) > 0 { + // DNS names from miekg/dns are already in punycode format + dom := domain.Domain(r.Question[0].Name) + + var newPrefixes []netip.Prefix + for _, answer := range r.Answer { + var ip netip.Addr + switch rr := answer.(type) { + case *dns.A: + addr, ok := netip.AddrFromSlice(rr.A) + if !ok { + log.Debugf("failed to convert A record IP: %v", rr.A) + continue + } + ip = addr + case *dns.AAAA: + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA) + continue + } + ip = addr + default: + continue + } + + prefix := netip.PrefixFrom(ip, ip.BitLen()) + newPrefixes = append(newPrefixes, prefix) + } + + if len(newPrefixes) > 0 { + if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil { + log.Errorf("failed to update domain prefixes: %v", err) + } + } + } + + if err := w.WriteMsg(r); err != nil { + return fmt.Errorf("failed to write DNS response: %v", err) + } + + return nil +} + +func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error { + d.mu.Lock() + defer d.mu.Unlock() + + oldPrefixes := d.interceptedDomains[domain] + toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) + + var merr *multierror.Error + + // Add new prefixes + for _, prefix := range toAdd { + if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) + continue + } + + if d.currentPeerKey == "" { + continue + } + if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != d.currentPeerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } + } + + if !d.route.KeepRoute { + // Remove old prefixes + for _, prefix := range toRemove { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + } + + // Update domain prefixes + if len(toAdd) > 0 || len(toRemove) > 0 { + d.interceptedDomains[domain] = newPrefixes + d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes) + + if len(toAdd) > 0 { + log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd) + } + if len(toRemove) > 0 { + log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { + prefixSet := make(map[netip.Prefix]bool) + for _, prefix := range oldPrefixes { + prefixSet[prefix] = false + } + for _, prefix := range newPrefixes { + if _, exists := prefixSet[prefix]; exists { + prefixSet[prefix] = true + } else { + toAdd = append(toAdd, prefix) + } + } + for prefix, inUse := range prefixSet { + if !inUse { + toRemove = append(toRemove, prefix) + } + } + return +} diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index ac94d4a5c..b71a91f74 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -74,11 +74,7 @@ func NewRoute( } func (r *Route) String() string { - s, err := r.route.Domains.String() - if err != nil { - return r.route.Domains.PunycodeString() - } - return s + return r.route.Domains.SafeString() } func (r *Route) AddRoute(ctx context.Context) error { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 8bf3a91b0..30899bc1d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -12,12 +12,15 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" @@ -33,9 +36,11 @@ import ( // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector + GetClientRoutes() route.HAMap + GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error @@ -60,6 +65,10 @@ type DefaultManager struct { allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration stateManager *statemanager.Manager + // clientRoutes is the most recent list of clientRoutes received from the Management Service + clientRoutes route.HAMap + dnsServer dns.Server + peerStore *peerstore.Store } func NewManager( @@ -71,6 +80,8 @@ func NewManager( relayMgr *relayClient.Manager, initialRoutes []*route.Route, stateManager *statemanager.Manager, + dnsServer dns.Server, + peerStore *peerstore.Store, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier := notifier.NewNotifier() @@ -88,6 +99,8 @@ func NewManager( pubKey: pubKey, notifier: notifier, stateManager: stateManager, + dnsServer: dnsServer, + peerStore: peerStore, } dm.routeRefCounter = refcounter.New( @@ -116,7 +129,7 @@ func NewManager( ) if runtime.GOOS == "android" { - cr := dm.clientRoutes(initialRoutes) + cr := dm.initialClientRoutes(initialRoutes) dm.notifier.SetInitialClientRoutes(cr) } return dm @@ -207,33 +220,40 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } m.ctx = nil + + m.mux.Lock() + defer m.mux.Unlock() + m.clientRoutes = nil } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") - return nil, nil, m.ctx.Err() + return nil default: - m.mux.Lock() - defer m.mux.Unlock() - - newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) - - filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) - m.updateClientNetworks(updateSerial, filteredClientRoutes) - m.notifier.OnNewRoutes(filteredClientRoutes) - - if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return nil, nil, fmt.Errorf("update routes: %w", err) - } - } - - return newServerRoutesMap, newClientRoutesIDMap, nil } + + m.mux.Lock() + defer m.mux.Unlock() + + newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) + + filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + m.updateClientNetworks(updateSerial, filteredClientRoutes) + m.notifier.OnNewRoutes(filteredClientRoutes) + + if m.serverRouter != nil { + err := m.serverRouter.updateRoutes(newServerRoutesMap) + if err != nil { + return err + } + } + + m.clientRoutes = newClientRoutesIDMap + + return nil } // SetRouteChangeListener set RouteListener for route change Notifier @@ -251,9 +271,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector { return m.routeSelector } -// GetClientRoutes returns the client routes -func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { - return m.clientNetworks +// GetClientRoutes returns most recent list of clientRoutes received from the Management Service +func (m *DefaultManager) GetClientRoutes() route.HAMap { + m.mux.Lock() + defer m.mux.Unlock() + + return maps.Clone(m.clientRoutes) +} + +// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only +func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + m.mux.Lock() + defer m.mux.Unlock() + + routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes)) + for id, v := range m.clientRoutes { + routes[id.NetID()] = v + } + return routes } // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones @@ -273,7 +308,17 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { continue } - clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher := newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) @@ -302,7 +347,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() } @@ -345,7 +390,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] return newServerRoutesMap, newClientRoutesIDMap } -func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { +func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route { _, crMap := m.classifyRoutes(initialRoutes) rs := make([]*route.Route, 0, len(crMap)) for _, routes := range crMap { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 07dac21b8..71b951593 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil) _, _, err = routeManager.Init() @@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) require.NoError(t, err, "should update routes with init routes") } - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 556a62351..0219b17c8 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -2,7 +2,6 @@ package routemanager import ( "context" - "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" @@ -15,10 +14,12 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) - TriggerSelectionFunc func(haMap route.HAMap) - GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func(manager *statemanager.Manager) + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + TriggerSelectionFunc func(haMap route.HAMap) + GetRouteSelectorFunc func() *routeselector.RouteSelector + GetClientRoutesFunc func() route.HAMap + GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route + StopFunc func(manager *statemanager.Manager) } func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { @@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) } - return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") + return nil } func (m *MockManager) TriggerSelection(networks route.HAMap) { @@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { return nil } +// GetClientRoutes mock implementation of GetClientRoutes from Manager interface +func (m *MockManager) GetClientRoutes() route.HAMap { + if m.GetClientRoutesFunc != nil { + return m.GetClientRoutesFunc() + } + return nil +} + +// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface +func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + if m.GetClientRoutesWithNetIDFunc != nil { + return m.GetClientRoutesWithNetIDFunc() + } + return nil +} + // Start mock implementation of Start from Manager interface func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 6f501e0c6..1b4413141 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() routeManager := engine.GetRouteManager() + routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } @@ -365,12 +365,12 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } @@ -392,12 +392,12 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } diff --git a/client/server/network.go b/client/server/network.go index ed204dd75..b4b4071b4 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -34,7 +34,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() + routesMap := engine.GetRouteManager().GetClientRoutesWithNetID() routeSelector := engine.GetRouteManager().GetRouteSelector() var routes []*selectRoute @@ -116,11 +116,12 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return &proto.SelectNetworksResponse{}, nil } @@ -145,11 +146,12 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return &proto.SelectNetworksResponse{}, nil } diff --git a/dns/dns.go b/dns/dns.go index 18528c743..8dfdf8526 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) { return validHost, nil } + +// NormalizeZone returns a normalized domain name without the wildcard prefix +func NormalizeZone(domain string) string { + d, _ := strings.CutPrefix(domain, "*.") + return d +} diff --git a/go.mod b/go.mod index 81d829462..53124aa69 100644 --- a/go.mod +++ b/go.mod @@ -207,6 +207,7 @@ require ( github.com/spf13/cast v1.5.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect diff --git a/go.sum b/go.sum index ac496ce0a..257644284 100644 --- a/go.sum +++ b/go.sum @@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 9d420066c..a29ba4562 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -360,7 +360,7 @@ func validateDomains(domains []string) (domain.List, error) { return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) } - domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) var domainList domain.List diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 879bc7fdb..4cee3ee30 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -330,6 +330,14 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "POST Wildcard Domain", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["*.example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)), + expectedStatus: http.StatusOK, + expectedBody: false, + }, { name: "POST UnprocessableEntity when both network and domains are provided", requestType: http.MethodPost, @@ -609,6 +617,30 @@ func TestValidateDomains(t *testing.T) { expected: domain.List{"google.com"}, wantErr: true, }, + { + name: "Valid wildcard domain", + domains: []string{"*.example.com"}, + expected: domain.List{"*.example.com"}, + wantErr: false, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid wildcard domain", + domains: []string{"a.*.example.com"}, + expected: nil, + wantErr: true, + }, } for _, tt := range tests {