From 5fee069379c75266dbbf4d5c9aee6f3d0d0fbab8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:19:06 +0100 Subject: [PATCH] Add handler chains (#3039) --------- Co-authored-by: Zoltan Papp --- client/internal/dns/handler_chain.go | 155 +++++++++ client/internal/dns/handler_chain_test.go | 319 ++++++++++++++++++ client/internal/dns/mock_server.go | 19 +- client/internal/dns/server.go | 80 +++-- client/internal/dns/service_memory.go | 1 - client/internal/dns/upstream.go | 5 + client/internal/dnsfwd/forwarder.go | 21 +- client/internal/engine_test.go | 5 +- .../routemanager/dnsinterceptor/handler.go | 60 ++-- client/internal/routemanager/dynamic/route.go | 6 +- dns/dns.go | 6 + go.mod | 1 + go.sum | 1 + 13 files changed, 588 insertions(+), 91 deletions(-) create mode 100644 client/internal/dns/handler_chain.go create mode 100644 client/internal/dns/handler_chain_test.go diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go new file mode 100644 index 000000000..6608e82f1 --- /dev/null +++ b/client/internal/dns/handler_chain.go @@ -0,0 +1,155 @@ +package dns + +import ( + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" +) + +const ( + PriorityDNSRoute = 100 + PriorityMatchDomain = 50 + PriorityDefault = 0 +) + +type HandlerEntry struct { + Handler dns.Handler + Priority int + Pattern string + IsWildcard bool + StopHandler handlerWithStop +} + +type HandlerChain struct { + mu sync.RWMutex + handlers []HandlerEntry +} + +// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain +type ResponseWriterChain struct { + dns.ResponseWriter + shouldContinue bool +} + +func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { + // Check if this is a continue signal (NXDOMAIN with Zero bit set) + if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero { + w.shouldContinue = true + return nil + } + return w.ResponseWriter.WriteMsg(m) +} + +func NewHandlerChain() *HandlerChain { + return &HandlerChain{ + handlers: make([]HandlerEntry, 0), + } +} + +func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { + c.mu.Lock() + defer c.mu.Unlock() + + isWildcard := strings.HasPrefix(pattern, "*.") + if isWildcard { + pattern = pattern[2:] + } + pattern = dns.Fqdn(pattern) + + log.Debugf("adding handler for pattern: %s (wildcard: %v) with priority %d", pattern, isWildcard, priority) + + entry := HandlerEntry{ + Handler: handler, + Priority: priority, + Pattern: pattern, + IsWildcard: isWildcard, + StopHandler: stopHandler, + } + + // Insert handler in priority order + pos := 0 + for i, h := range c.handlers { + if h.Priority < priority { + pos = i + break + } + pos = i + 1 + } + + c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) +} + +func (c *HandlerChain) RemoveHandler(pattern string) { + c.mu.Lock() + defer c.mu.Unlock() + + pattern = dns.Fqdn(pattern) + for i, entry := range c.handlers { + if entry.Pattern == pattern { + if entry.StopHandler != nil { + entry.StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + return + } + } +} + +func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + + qname := r.Question[0].Name + log.Debugf("handling DNS request for %s", qname) + + c.mu.RLock() + defer c.mu.RUnlock() + + log.Debugf("current handlers (%d):", len(c.handlers)) + for _, h := range c.handlers { + log.Debugf(" - pattern: %s, wildcard: %v, priority: %d", h.Pattern, h.IsWildcard, h.Priority) + } + + // Try handlers in priority order + for _, entry := range c.handlers { + var matched bool + switch { + case entry.Pattern == ".": + matched = true + case entry.IsWildcard: + parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") + matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) + default: + matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + } + + if !matched { + log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false", + entry.Pattern, qname, entry.IsWildcard) + continue + } + + log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v", + entry.Pattern, qname, entry.IsWildcard) + chainWriter := &ResponseWriterChain{ResponseWriter: w} + entry.Handler.ServeDNS(chainWriter, r) + + // If handler wants to continue, try next handler + if chainWriter.shouldContinue { + log.Debugf("handler requested continue to next handler") + continue + } + return + } + + // No handler matched or all handlers passed + log.Debugf("no handler found for %s", qname) + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go new file mode 100644 index 000000000..f97483019 --- /dev/null +++ b/client/internal/dns/handler_chain_test.go @@ -0,0 +1,319 @@ +package dns_test + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + nbdns "github.com/netbirdio/netbird/client/internal/dns" +) + +// MockHandler implements dns.Handler interface for testing +type MockHandler struct { + mock.Mock +} + +func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + m.Called(w, r) +} + +// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order +func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create mock handlers for different priorities + defaultHandler := &MockHandler{} + matchDomainHandler := &MockHandler{} + dnsRouteHandler := &MockHandler{} + + // Setup handlers with different priorities + chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Create test writer + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations - only highest priority handler should be called + dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() + matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe() + defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() + + // Execute + chain.ServeDNS(w, r) + + // Verify all expectations were met + dnsRouteHandler.AssertExpectations(t) + matchDomainHandler.AssertExpectations(t) + defaultHandler.AssertExpectations(t) +} + +// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios +func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { + tests := []struct { + name string + handlerDomain string + queryDomain string + isWildcard bool + shouldMatch bool + }{ + { + name: "exact match", + handlerDomain: "example.com.", + queryDomain: "example.com.", + isWildcard: false, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + shouldMatch: true, + }, + { + name: "wildcard match", + handlerDomain: "*.example.com.", + queryDomain: "sub.example.com.", + isWildcard: true, + shouldMatch: true, + }, + { + name: "wildcard no match on apex", + handlerDomain: "*.example.com.", + queryDomain: "example.com.", + isWildcard: true, + shouldMatch: false, + }, + { + name: "root zone match", + handlerDomain: ".", + queryDomain: "anything.com.", + isWildcard: false, + shouldMatch: true, + }, + { + name: "no match different domain", + handlerDomain: "example.com.", + queryDomain: "example.org.", + isWildcard: false, + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + mockHandler := &MockHandler{} + + pattern := tt.handlerDomain + if tt.isWildcard { + pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard + } + + chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil) + + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + if tt.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, r).Once() + } + + chain.ServeDNS(w, r) + mockHandler.AssertExpectations(t) + }) + } +} + +// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns +func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { + tests := []struct { + name string + handlers []struct { + pattern string + priority int + } + queryDomain string + expectedCalls int + expectedHandler int // index of the handler that should be called + }{ + { + name: "wildcard and exact same priority - exact should win", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // exact match handler should be called + }, + { + name: "higher priority wildcard over lower priority exact", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority wildcard handler should be called + }, + { + name: "multiple wildcards different priorities", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority handler should be called + }, + { + name: "subdomain with mix of patterns", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "sub.test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority matching handler should be called + }, + { + name: "root zone with specific domain", + handlers: []struct { + pattern string + priority int + }{ + {pattern: ".", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority specific domain should win over root + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handlers []*MockHandler + + // Setup handlers and expectations + for i := range tt.handlers { + handler := &MockHandler{} + handlers = append(handlers, handler) + + // Set expectation based on whether this handler should be called + if i == tt.expectedHandler { + handler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } else { + handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() + } + + chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) + } + + // Create and execute request + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality +func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create handlers + handler1 := &MockHandler{} + handler2 := &MockHandler{} + handler3 := &MockHandler{} + + // Add handlers in priority order + chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) + chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Setup mock responses to simulate chain continuation + handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // First handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true // Signal to continue + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Second handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Last handler responds normally + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + // Execute + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify all handlers were called in order + handler1.AssertExpectations(t) + handler2.AssertExpectations(t) + handler3.AssertExpectations(t) +} + +// mockResponseWriter implements dns.ResponseWriter for testing +type mockResponseWriter struct { + mock.Mock +} + +func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } +func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } +func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) Close() error { return nil } +func (m *mockResponseWriter) TsigStatus() error { return nil } +func (m *mockResponseWriter) TsigTimersOnly(bool) {} +func (m *mockResponseWriter) Hijack() {} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 1ec86a7ee..37a93fdfe 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -13,27 +13,20 @@ type MockServer struct { InitializeFunc func() error StopFunc func() UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func([]string, dns.Handler) error - UnregisterHandlerFunc func([]string) error - DeregisterHandlerFunc func([]string) error + RegisterHandlerFunc func([]string, dns.Handler, int) + DeregisterHandlerFunc func([]string) } -func (m *MockServer) UnregisterHandler(domains []string) error { - panic("implement me") -} - -func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error { +func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { if m.RegisterHandlerFunc != nil { - return m.RegisterHandlerFunc(domains, handler) + m.RegisterHandlerFunc(domains, handler, priority) } - return nil } -func (m *MockServer) DeregisterHandler(domains []string) error { +func (m *MockServer) DeregisterHandler(domains []string) { if m.DeregisterHandlerFunc != nil { - return m.DeregisterHandlerFunc(domains) + m.DeregisterHandlerFunc(domains) } - return nil } // Initialize mock implementation of Initialize from Server interface diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index aa96f2306..6e45f0390 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -30,7 +30,8 @@ type IosDnsManager interface { // Server is a dns server interface type Server interface { - RegisterHandler(domains []string, handler dns.Handler) error + RegisterHandler(domains []string, handler dns.Handler, priority int) + DeregisterHandler(domains []string) Initialize() error Stop() DnsIP() string @@ -38,7 +39,6 @@ type Server interface { OnUpdatedHostDNSServer(strings []string) SearchDomains() []string ProbeAvailability() - DeregisterHandler(domains []string) error } type registeredHandlerMap map[string]handlerWithStop @@ -56,6 +56,7 @@ type DefaultServer struct { updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + handlerChain *HandlerChain // permanent related properties permanent bool @@ -76,8 +77,9 @@ type handlerWithStop interface { } type muxUpdate struct { - domain string - handler handlerWithStop + domain string + handler handlerWithStop + priority int } // NewDefaultServer returns a new dns server @@ -137,10 +139,11 @@ func NewDefaultServerIos( func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - service: dnsService, - dnsMuxMap: make(registeredHandlerMap), + ctx: ctx, + ctxCancel: stop, + service: dnsService, + handlerChain: NewHandlerChain(), + dnsMuxMap: make(registeredHandlerMap), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -153,32 +156,38 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi return defaultServer } -func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) error { +func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { s.mux.Lock() defer s.mux.Unlock() - log.Debugf("registering handler %s", handler) - for _, domain := range domains { - wosuff, _ := strings.CutPrefix(domain, "*.") - pattern := dns.Fqdn(wosuff) - s.service.RegisterMux(pattern, handler) - } - - return nil + s.registerHandler(domains, handler, priority) } -func (s *DefaultServer) DeregisterHandler(domains []string) error { +// registerhandler without lock +func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { + log.Debugf("registering handler %s with priority %d", handler, priority) + + for _, domain := range domains { + s.handlerChain.AddHandler(domain, handler, priority, nil) + + s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) + } +} + +func (s *DefaultServer) DeregisterHandler(domains []string) { s.mux.Lock() defer s.mux.Unlock() + s.deregisterHandler(domains) +} + +func (s *DefaultServer) deregisterHandler(domains []string) { log.Debugf("unregistering handler for domains %s", domains) for _, domain := range domains { - wosuff, _ := strings.CutPrefix(domain, "*.") - pattern := dns.Fqdn(wosuff) - s.service.DeregisterMux(pattern) - } + s.handlerChain.RemoveHandler(domain) - return nil + s.service.DeregisterMux(nbdns.NormalizeZone(domain)) + } } // Initialize instantiate host manager and the dns service @@ -442,8 +451,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam if nsGroup.Primary { muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, + domain: nbdns.RootZone, + handler: handler, + priority: PriorityDefault, }) continue } @@ -459,8 +469,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam return nil, fmt.Errorf("received a nameserver group with an empty domain element") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, + domain: domain, + handler: handler, + priority: PriorityMatchDomain, }) } } @@ -474,7 +485,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { var isContainRootUpdate bool for _, update := range muxUpdates { - s.service.RegisterMux(update.domain, update.handler) + s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.domain] = update.handler if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { existingHandler.stop() @@ -493,7 +504,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { existingHandler.stop() } else { existingHandler.stop() - s.service.DeregisterMux(key) + s.deregisterHandler([]string{key}) } } } @@ -547,13 +558,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.service.DeregisterMux(nbdns.RootZone) + s.deregisterHandler([]string{nbdns.RootZone}) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.service.DeregisterMux(item.Domain) + s.deregisterHandler([]string{item.Domain}) removeIndex[item.Domain] = i } } @@ -584,7 +595,7 @@ func (s *DefaultServer) upstreamCallbacks( continue } s.currentConfig.Domains[i].Disabled = false - s.service.RegisterMux(domain, handler) + s.registerHandler([]string{domain}, handler, PriorityMatchDomain) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -592,7 +603,7 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.service.RegisterMux(nbdns.RootZone, handler) + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") @@ -623,7 +634,8 @@ func (s *DefaultServer) addHostRootZone() { } handler.deactivate = func(error) {} handler.reactivate = func() {} - s.service.RegisterMux(nbdns.RootZone, handler) + + s.RegisterHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index e198249ff..729b90cc0 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -68,7 +68,6 @@ func (s *ServiceViaMemory) Stop() { } func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) { - log.Debugf("registering dns handler for pattern: %s", pattern) s.dnsMux.Handle(pattern, handler) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index b3baf2fa8..94497c61f 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) * } } +// String returns a string representation of the upstream resolver +func (u *upstreamResolverBase) String() string { + return fmt.Sprintf("%v", u.upstreamServers) +} + func (u *upstreamResolverBase) stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 1ffde7e49..dd9636158 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -6,6 +6,8 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" ) type DNSForwarder struct { @@ -18,6 +20,7 @@ type DNSForwarder struct { } func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder { + log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, @@ -29,7 +32,7 @@ func (f *DNSForwarder) Listen() error { mux := dns.NewServeMux() for _, d := range f.domains { - mux.HandleFunc(d, f.handleDNSQuery) + mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery) } dnsServer := &dns.Server{ @@ -47,8 +50,8 @@ func (f *DNSForwarder) UpdateDomains(domains []string) { f.mux.HandleRemove(d) } - for _, d := range domains { - f.mux.HandleFunc(d, f.handleDNSQuery) + for _, d := range f.domains { + f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery) } f.domains = domains } @@ -61,10 +64,10 @@ func (f *DNSForwarder) Close(ctx context.Context) error { } func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { - log.Tracef("received DNS query for DNS forwarder: %v", query) if len(query.Question) == 0 { return } + log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name) question := query.Question[0] domain := question.Name @@ -74,16 +77,17 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { ips, err := net.LookupIP(domain) if err != nil { log.Warnf("failed to resolve query for domain %s: %v", domain, err) - resp.Rcode = dns.RcodeRefused - _ = w.WriteMsg(resp) + resp.Rcode = dns.RcodeServerFailure + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write failure DNS response: %v", err) + } return } for _, ip := range ips { - log.Infof("resolved domain %s to IP %s", domain, ip) var respRecord dns.RR if ip.To4() == nil { - log.Infof("resolved domain %s to IPv6 %s", domain, ip) + log.Tracef("resolved domain %s to IPv6 %s", domain, ip) rr := dns.AAAA{ AAAA: ip, Hdr: dns.RR_Header{ @@ -95,6 +99,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { } respRecord = &rr } else { + log.Tracef("resolved domain %s to IPv4 %s", domain, ip) rr := dns.A{ A: ip, Hdr: dns.RR_Header{ diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d717b7917..f1fec67e7 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/google/uuid" - miekdns "github.com/miekg/dns" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -103,9 +102,7 @@ func TestEngine_SSH(t *testing.T) { ) engine.dnsServer = &dns.MockServer{ - UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, - RegisterHandlerFunc: func(domains []string, handler miekdns.Handler) error { return nil }, - DeregisterHandlerFunc: func(domains []string) error { return nil }, + UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } var sshKeysAdded []string diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index d384c2941..0a52485be 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -57,15 +57,12 @@ func New( } func (d *DnsInterceptor) String() string { - s, err := d.route.Domains.String() - if err != nil { - return d.route.Domains.PunycodeString() - } - return s + return d.route.Domains.SafeString() } func (d *DnsInterceptor) AddRoute(context.Context) error { - return d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d) + d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute) + return nil } func (d *DnsInterceptor) RemoveRoute() error { @@ -91,9 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error { clear(d.interceptedDomains) - if err := d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList()); err != nil { - merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err)) - } + d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList()) return nberrors.FormatErrorOrNil(merr) } @@ -143,23 +138,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { return } - log.Debugf("received DNS request: %v", r.Question[0].Name) + log.Tracef("received DNS request: %v", r.Question[0].Name) - if d.currentPeerKey == "" { - // TODO: call normal upstream instead of returning an error? - log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name) - if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { - log.Errorf("failed writing DNS response: %v", err) - } + d.mu.RLock() + peerKey := d.currentPeerKey + d.mu.RUnlock() + + if peerKey == "" { + log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name) + d.continueToNextHandler(w, r, "no current peer key") return } - upstreamIP, err := d.getUpstreamIP() + upstreamIP, err := d.getUpstreamIP(peerKey) if err != nil { log.Errorf("failed to get upstream IP: %v", err) - if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { - log.Errorf("failed writing DNS response: %v", err) - } + d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err)) return } @@ -169,7 +163,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) reply, _, err := client.ExchangeContext(context.Background(), r, upstream) - log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer) + + var answer []dns.RR + if reply != nil { + answer = reply.Answer + } + log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer) if err != nil { log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) @@ -185,13 +184,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) { - d.mu.RLock() - defer d.mu.RUnlock() +// continueToNextHandler signals the handler chain to try the next handler +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { + log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + // Set Zero bit to signal handler chain to continue + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed writing DNS continue response: %v", err) + } +} - peerAllowedIP, exists := d.peerStore.AllowedIP(d.currentPeerKey) +func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { + peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) if !exists { - return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey) + return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) } return peerAllowedIP, nil } diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index ac94d4a5c..b71a91f74 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -74,11 +74,7 @@ func NewRoute( } func (r *Route) String() string { - s, err := r.route.Domains.String() - if err != nil { - return r.route.Domains.PunycodeString() - } - return s + return r.route.Domains.SafeString() } func (r *Route) AddRoute(ctx context.Context) error { diff --git a/dns/dns.go b/dns/dns.go index 18528c743..8dfdf8526 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) { return validHost, nil } + +// NormalizeZone returns a normalized domain name without the wildcard prefix +func NormalizeZone(domain string) string { + d, _ := strings.CutPrefix(domain, "*.") + return d +} diff --git a/go.mod b/go.mod index 2b4111ce3..14f800036 100644 --- a/go.mod +++ b/go.mod @@ -207,6 +207,7 @@ require ( github.com/spf13/cast v1.5.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect diff --git a/go.sum b/go.sum index 35abe82d2..3bdeb6827 100644 --- a/go.sum +++ b/go.sum @@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=