diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index c3e613553..87de168a5 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,7 +3,8 @@ package dns import ( "fmt" - "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" ) @@ -14,7 +15,7 @@ type MockServer struct { UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error } -func (m *MockServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error { +func (m *MockServer) RegisterHandler([]string, dns.Handler) error { return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 3d424f9b1..198c0f11c 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -14,7 +14,6 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -31,7 +30,7 @@ type IosDnsManager interface { // Server is a dns server interface type Server interface { - RegisterHandler(handler *dnsinterceptor.RouteMatchHandler) error + RegisterHandler(domains []string, handler dns.Handler) error Initialize() error Stop() DnsIP() string @@ -153,7 +152,16 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi return defaultServer } -func (m *DefaultServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error { +func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) error { + s.mux.Lock() + defer s.mux.Unlock() + + log.Debugf("registering handler %s", handler) + for _, domain := range domains { + pattern := dns.Fqdn(domain) + s.service.RegisterMux(pattern, handler) + } + return nil } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 8cc5196de..5b47f5993 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/netip" - "strings" "sync" "github.com/miekg/dns" @@ -12,12 +11,11 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/route" ) -type RouteMatchHandler struct { +type DnsInterceptor struct { mu sync.RWMutex route *route.Route routeRefCounter *refcounter.RouteRefCounter @@ -25,7 +23,7 @@ type RouteMatchHandler struct { statusRecorder *peer.Status dnsServer nbdns.Server currentPeerKey string - domainRoutes map[string]*route.Route + interceptedIPs map[string]netip.Prefix } func New( @@ -34,144 +32,132 @@ func New( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, statusRecorder *peer.Status, dnsServer nbdns.Server, -) routemanager.RouteHandler { - - return &RouteMatchHandler{ +) *DnsInterceptor { + return &DnsInterceptor{ route: rt, routeRefCounter: routeRefCounter, allowedIPsRefcounter: allowedIPsRefCounter, statusRecorder: statusRecorder, dnsServer: dnsServer, - domainRoutes: make(map[string]*route.Route), + interceptedIPs: make(map[string]netip.Prefix), } } -func (h *RouteMatchHandler) String() string { - return fmt.Sprintf("dns route for domains: %v", h.route.Domains) +func (h *DnsInterceptor) String() string { + s, err := h.route.Domains.String() + if err != nil { + return h.route.Domains.PunycodeString() + } + return s } -func (h *RouteMatchHandler) AddRoute(ctx context.Context) error { +func (h *DnsInterceptor) AddRoute(context.Context) error { + return h.dnsServer.RegisterHandler(h.route.Domains.ToPunycodeList(), h) +} + +func (h *DnsInterceptor) RemoveRoute() error { h.mu.Lock() defer h.mu.Unlock() - for _, domain := range h.route.Domains { - pattern := dns.Fqdn(string(domain)) - h.domainRoutes[pattern] = h.route + // Remove all intercepted IPs + for key, prefix := range h.interceptedIPs { + if _, err := h.routeRefCounter.Decrement(prefix); err != nil { + log.Errorf("Failed to remove route for IP %s: %v", prefix, err) + } + if h.currentPeerKey != "" { + if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil { + log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) + } + } + delete(h.interceptedIPs, key) } - return h.dnsServer.RegisterHandler(h) -} + // TODO: remove from mux -func (h *RouteMatchHandler) RemoveRoute() error { - h.mu.Lock() - defer h.mu.Unlock() - - h.domainRoutes = make(map[string]*route.Route) return nil } -func (h *RouteMatchHandler) AddAllowedIPs(peerKey string) error { +func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error { h.mu.Lock() defer h.mu.Unlock() + h.currentPeerKey = peerKey + + // Re-add all intercepted IPs for the new peer + for _, prefix := range h.interceptedIPs { + if _, err := h.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { + log.Errorf("Failed to add allowed IP %s: %v", prefix, err) + } + } + return nil } -func (h *RouteMatchHandler) RemoveAllowedIPs() error { +func (h *DnsInterceptor) RemoveAllowedIPs() error { h.mu.Lock() defer h.mu.Unlock() + + if h.currentPeerKey != "" { + for _, prefix := range h.interceptedIPs { + if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil { + log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) + } + } + } + h.currentPeerKey = "" return nil } -type responseInterceptor struct { - dns.ResponseWriter - handler *RouteMatchHandler - question dns.Question - answered bool -} - -func (i *responseInterceptor) WriteMsg(resp *dns.Msg) error { - if i.answered { - return nil - } - i.answered = true - - if resp == nil || len(resp.Answer) == 0 { - return i.ResponseWriter.WriteMsg(resp) +// ServeDNS implements the dns.Handler interface +func (h *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + log.Debugf("Received DNS request: %v", r) + if len(r.Question) == 0 { + return } - i.handler.mu.RLock() - defer i.handler.mu.RUnlock() - - questionName := i.question.Name - for _, ans := range resp.Answer { - var ip netip.Addr - switch rr := ans.(type) { - case *dns.A: - addr, ok := netip.AddrFromSlice(rr.A) - if !ok { - continue - } - ip = addr - case *dns.AAAA: - addr, ok := netip.AddrFromSlice(rr.AAAA) - if !ok { - continue - } - ip = addr - default: - continue - } - - if route := i.handler.findMatchingRoute(questionName); route != nil { - i.handler.processMatch(route, questionName, ip) - } - } - - return i.ResponseWriter.WriteMsg(resp) -} - -func (h *RouteMatchHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + // Create response interceptor to capture the response interceptor := &responseInterceptor{ ResponseWriter: w, handler: h, question: r.Question[0], + answered: false, } - h.dnsServer.ServeDNS(interceptor, r) + // Let the request pass through with our interceptor + err := interceptor.WriteMsg(r) + if err != nil { + log.Errorf("Failed writing DNS response: %v", err) + } } -func (h *RouteMatchHandler) findMatchingRoute(domain string) *route.Route { - domain = strings.ToLower(domain) +func (h *DnsInterceptor) processMatch(domain string, ip netip.Addr) { + h.mu.Lock() + defer h.mu.Unlock() - if route, ok := h.domainRoutes[domain]; ok { - return route - } - - labels := dns.SplitDomainName(domain) - if labels == nil { - return nil - } - - for i := 0; i < len(labels); i++ { - wildcard := "*." + strings.Join(labels[i:], ".") + "." - if route, ok := h.domainRoutes[wildcard]; ok { - return route - } - } - - return nil -} - -func (h *RouteMatchHandler) processMatch(route *route.Route, domain string, ip netip.Addr) { network := netip.PrefixFrom(ip, ip.BitLen()) + key := fmt.Sprintf("%s:%s", domain, network.String()) - if h.currentPeerKey == "" { + if _, exists := h.interceptedIPs[key]; exists { return } - if err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil { - log.Errorf("Failed to add allowed IP %s: %v", network, err) + if _, err := h.routeRefCounter.Increment(network, struct{}{}); err != nil { + log.Errorf("Failed to add route for IP %s: %v", network, err) + return } + + if h.currentPeerKey != "" { + if _, err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil { + log.Errorf("Failed to add allowed IP %s: %v", network, err) + // Rollback route addition + if _, err := h.routeRefCounter.Decrement(network); err != nil { + log.Errorf("Failed to rollback route addition for IP %s: %v", network, err) + } + return + } + } + + h.interceptedIPs[key] = network + log.Debugf("Added route for domain %s -> %s", domain, network) } diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index f44a164e2..0b469cd08 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -350,7 +350,8 @@ func validateDomains(domains []string) (domain.List, error) { return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) } - domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + domainRegex := regexp.MustCompile(``) + //domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) var domainList domain.List