This commit is contained in:
Viktor Liu
2024-12-10 15:57:06 +01:00
parent d77ac20760
commit 16a2867d69
4 changed files with 98 additions and 102 deletions

View File

@ -3,7 +3,8 @@ package dns
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@ -14,7 +15,7 @@ type MockServer struct {
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
} }
func (m *MockServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error { func (m *MockServer) RegisterHandler([]string, dns.Handler) error {
return nil return nil
} }

View File

@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@ -31,7 +30,7 @@ type IosDnsManager interface {
// Server is a dns server interface // Server is a dns server interface
type Server interface { type Server interface {
RegisterHandler(handler *dnsinterceptor.RouteMatchHandler) error RegisterHandler(domains []string, handler dns.Handler) error
Initialize() error Initialize() error
Stop() Stop()
DnsIP() string DnsIP() string
@ -153,7 +152,16 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
return defaultServer return defaultServer
} }
func (m *DefaultServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error { func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) error {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("registering handler %s", handler)
for _, domain := range domains {
pattern := dns.Fqdn(domain)
s.service.RegisterMux(pattern, handler)
}
return nil return nil
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"sync" "sync"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -12,12 +11,11 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type RouteMatchHandler struct { type DnsInterceptor struct {
mu sync.RWMutex mu sync.RWMutex
route *route.Route route *route.Route
routeRefCounter *refcounter.RouteRefCounter routeRefCounter *refcounter.RouteRefCounter
@ -25,7 +23,7 @@ type RouteMatchHandler struct {
statusRecorder *peer.Status statusRecorder *peer.Status
dnsServer nbdns.Server dnsServer nbdns.Server
currentPeerKey string currentPeerKey string
domainRoutes map[string]*route.Route interceptedIPs map[string]netip.Prefix
} }
func New( func New(
@ -34,144 +32,132 @@ func New(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status, statusRecorder *peer.Status,
dnsServer nbdns.Server, dnsServer nbdns.Server,
) routemanager.RouteHandler { ) *DnsInterceptor {
return &DnsInterceptor{
return &RouteMatchHandler{
route: rt, route: rt,
routeRefCounter: routeRefCounter, routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter, allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
dnsServer: dnsServer, dnsServer: dnsServer,
domainRoutes: make(map[string]*route.Route), interceptedIPs: make(map[string]netip.Prefix),
} }
} }
func (h *RouteMatchHandler) String() string { func (h *DnsInterceptor) String() string {
return fmt.Sprintf("dns route for domains: %v", h.route.Domains) s, err := h.route.Domains.String()
if err != nil {
return h.route.Domains.PunycodeString()
}
return s
} }
func (h *RouteMatchHandler) AddRoute(ctx context.Context) error { func (h *DnsInterceptor) AddRoute(context.Context) error {
return h.dnsServer.RegisterHandler(h.route.Domains.ToPunycodeList(), h)
}
func (h *DnsInterceptor) RemoveRoute() error {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
for _, domain := range h.route.Domains { // Remove all intercepted IPs
pattern := dns.Fqdn(string(domain)) for key, prefix := range h.interceptedIPs {
h.domainRoutes[pattern] = h.route if _, err := h.routeRefCounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
}
if h.currentPeerKey != "" {
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
}
}
delete(h.interceptedIPs, key)
} }
return h.dnsServer.RegisterHandler(h) // TODO: remove from mux
}
func (h *RouteMatchHandler) RemoveRoute() error {
h.mu.Lock()
defer h.mu.Unlock()
h.domainRoutes = make(map[string]*route.Route)
return nil return nil
} }
func (h *RouteMatchHandler) AddAllowedIPs(peerKey string) error { func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
h.currentPeerKey = peerKey h.currentPeerKey = peerKey
// Re-add all intercepted IPs for the new peer
for _, prefix := range h.interceptedIPs {
if _, err := h.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
}
}
return nil return nil
} }
func (h *RouteMatchHandler) RemoveAllowedIPs() error { func (h *DnsInterceptor) RemoveAllowedIPs() error {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
if h.currentPeerKey != "" {
for _, prefix := range h.interceptedIPs {
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
}
}
}
h.currentPeerKey = "" h.currentPeerKey = ""
return nil return nil
} }
type responseInterceptor struct { // ServeDNS implements the dns.Handler interface
dns.ResponseWriter func (h *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
handler *RouteMatchHandler log.Debugf("Received DNS request: %v", r)
question dns.Question if len(r.Question) == 0 {
answered bool return
}
func (i *responseInterceptor) WriteMsg(resp *dns.Msg) error {
if i.answered {
return nil
}
i.answered = true
if resp == nil || len(resp.Answer) == 0 {
return i.ResponseWriter.WriteMsg(resp)
} }
i.handler.mu.RLock() // Create response interceptor to capture the response
defer i.handler.mu.RUnlock()
questionName := i.question.Name
for _, ans := range resp.Answer {
var ip netip.Addr
switch rr := ans.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
ip = addr
default:
continue
}
if route := i.handler.findMatchingRoute(questionName); route != nil {
i.handler.processMatch(route, questionName, ip)
}
}
return i.ResponseWriter.WriteMsg(resp)
}
func (h *RouteMatchHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
interceptor := &responseInterceptor{ interceptor := &responseInterceptor{
ResponseWriter: w, ResponseWriter: w,
handler: h, handler: h,
question: r.Question[0], question: r.Question[0],
answered: false,
} }
h.dnsServer.ServeDNS(interceptor, r) // Let the request pass through with our interceptor
err := interceptor.WriteMsg(r)
if err != nil {
log.Errorf("Failed writing DNS response: %v", err)
}
} }
func (h *RouteMatchHandler) findMatchingRoute(domain string) *route.Route { func (h *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
domain = strings.ToLower(domain) h.mu.Lock()
defer h.mu.Unlock()
if route, ok := h.domainRoutes[domain]; ok {
return route
}
labels := dns.SplitDomainName(domain)
if labels == nil {
return nil
}
for i := 0; i < len(labels); i++ {
wildcard := "*." + strings.Join(labels[i:], ".") + "."
if route, ok := h.domainRoutes[wildcard]; ok {
return route
}
}
return nil
}
func (h *RouteMatchHandler) processMatch(route *route.Route, domain string, ip netip.Addr) {
network := netip.PrefixFrom(ip, ip.BitLen()) network := netip.PrefixFrom(ip, ip.BitLen())
key := fmt.Sprintf("%s:%s", domain, network.String())
if h.currentPeerKey == "" { if _, exists := h.interceptedIPs[key]; exists {
return return
} }
if err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil { if _, err := h.routeRefCounter.Increment(network, struct{}{}); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", network, err) log.Errorf("Failed to add route for IP %s: %v", network, err)
return
} }
if h.currentPeerKey != "" {
if _, err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", network, err)
// Rollback route addition
if _, err := h.routeRefCounter.Decrement(network); err != nil {
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
}
return
}
}
h.interceptedIPs[key] = network
log.Debugf("Added route for domain %s -> %s", domain, network)
} }

View File

@ -350,7 +350,8 @@ func validateDomains(domains []string) (domain.List, error) {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
} }
domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) domainRegex := regexp.MustCompile(``)
//domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
var domainList domain.List var domainList domain.List