mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-14 17:28:56 +02:00
f
This commit is contained in:
@ -3,7 +3,8 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,7 +15,7 @@ type MockServer struct {
|
|||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
|
func (m *MockServer) RegisterHandler([]string, dns.Handler) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
@ -31,7 +30,7 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(handler *dnsinterceptor.RouteMatchHandler) error
|
RegisterHandler(domains []string, handler dns.Handler) error
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@ -153,7 +152,16 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
|
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("registering handler %s", handler)
|
||||||
|
for _, domain := range domains {
|
||||||
|
pattern := dns.Fqdn(domain)
|
||||||
|
s.service.RegisterMux(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -12,12 +11,11 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouteMatchHandler struct {
|
type DnsInterceptor struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
route *route.Route
|
route *route.Route
|
||||||
routeRefCounter *refcounter.RouteRefCounter
|
routeRefCounter *refcounter.RouteRefCounter
|
||||||
@ -25,7 +23,7 @@ type RouteMatchHandler struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
dnsServer nbdns.Server
|
dnsServer nbdns.Server
|
||||||
currentPeerKey string
|
currentPeerKey string
|
||||||
domainRoutes map[string]*route.Route
|
interceptedIPs map[string]netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
func New(
|
||||||
@ -34,144 +32,132 @@ func New(
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
dnsServer nbdns.Server,
|
dnsServer nbdns.Server,
|
||||||
) routemanager.RouteHandler {
|
) *DnsInterceptor {
|
||||||
|
return &DnsInterceptor{
|
||||||
return &RouteMatchHandler{
|
|
||||||
route: rt,
|
route: rt,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: routeRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
dnsServer: dnsServer,
|
dnsServer: dnsServer,
|
||||||
domainRoutes: make(map[string]*route.Route),
|
interceptedIPs: make(map[string]netip.Prefix),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RouteMatchHandler) String() string {
|
func (h *DnsInterceptor) String() string {
|
||||||
return fmt.Sprintf("dns route for domains: %v", h.route.Domains)
|
s, err := h.route.Domains.String()
|
||||||
|
if err != nil {
|
||||||
|
return h.route.Domains.PunycodeString()
|
||||||
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RouteMatchHandler) AddRoute(ctx context.Context) error {
|
func (h *DnsInterceptor) AddRoute(context.Context) error {
|
||||||
|
return h.dnsServer.RegisterHandler(h.route.Domains.ToPunycodeList(), h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DnsInterceptor) RemoveRoute() error {
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
for _, domain := range h.route.Domains {
|
// Remove all intercepted IPs
|
||||||
pattern := dns.Fqdn(string(domain))
|
for key, prefix := range h.interceptedIPs {
|
||||||
h.domainRoutes[pattern] = h.route
|
if _, err := h.routeRefCounter.Decrement(prefix); err != nil {
|
||||||
|
log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
if h.currentPeerKey != "" {
|
||||||
|
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
|
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(h.interceptedIPs, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.dnsServer.RegisterHandler(h)
|
// TODO: remove from mux
|
||||||
}
|
|
||||||
|
|
||||||
func (h *RouteMatchHandler) RemoveRoute() error {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
|
|
||||||
h.domainRoutes = make(map[string]*route.Route)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RouteMatchHandler) AddAllowedIPs(peerKey string) error {
|
func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
h.currentPeerKey = peerKey
|
h.currentPeerKey = peerKey
|
||||||
|
|
||||||
|
// Re-add all intercepted IPs for the new peer
|
||||||
|
for _, prefix := range h.interceptedIPs {
|
||||||
|
if _, err := h.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||||
|
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RouteMatchHandler) RemoveAllowedIPs() error {
|
func (h *DnsInterceptor) RemoveAllowedIPs() error {
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
defer h.mu.Unlock()
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
if h.currentPeerKey != "" {
|
||||||
|
for _, prefix := range h.interceptedIPs {
|
||||||
|
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
|
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
h.currentPeerKey = ""
|
h.currentPeerKey = ""
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type responseInterceptor struct {
|
// ServeDNS implements the dns.Handler interface
|
||||||
dns.ResponseWriter
|
func (h *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
handler *RouteMatchHandler
|
log.Debugf("Received DNS request: %v", r)
|
||||||
question dns.Question
|
if len(r.Question) == 0 {
|
||||||
answered bool
|
return
|
||||||
}
|
|
||||||
|
|
||||||
func (i *responseInterceptor) WriteMsg(resp *dns.Msg) error {
|
|
||||||
if i.answered {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
i.answered = true
|
|
||||||
|
|
||||||
if resp == nil || len(resp.Answer) == 0 {
|
|
||||||
return i.ResponseWriter.WriteMsg(resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
i.handler.mu.RLock()
|
// Create response interceptor to capture the response
|
||||||
defer i.handler.mu.RUnlock()
|
|
||||||
|
|
||||||
questionName := i.question.Name
|
|
||||||
for _, ans := range resp.Answer {
|
|
||||||
var ip netip.Addr
|
|
||||||
switch rr := ans.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
addr, ok := netip.AddrFromSlice(rr.A)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ip = addr
|
|
||||||
case *dns.AAAA:
|
|
||||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ip = addr
|
|
||||||
default:
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if route := i.handler.findMatchingRoute(questionName); route != nil {
|
|
||||||
i.handler.processMatch(route, questionName, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return i.ResponseWriter.WriteMsg(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *RouteMatchHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
interceptor := &responseInterceptor{
|
interceptor := &responseInterceptor{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
handler: h,
|
handler: h,
|
||||||
question: r.Question[0],
|
question: r.Question[0],
|
||||||
|
answered: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
h.dnsServer.ServeDNS(interceptor, r)
|
// Let the request pass through with our interceptor
|
||||||
|
err := interceptor.WriteMsg(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed writing DNS response: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RouteMatchHandler) findMatchingRoute(domain string) *route.Route {
|
func (h *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
|
||||||
domain = strings.ToLower(domain)
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
if route, ok := h.domainRoutes[domain]; ok {
|
|
||||||
return route
|
|
||||||
}
|
|
||||||
|
|
||||||
labels := dns.SplitDomainName(domain)
|
|
||||||
if labels == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(labels); i++ {
|
|
||||||
wildcard := "*." + strings.Join(labels[i:], ".") + "."
|
|
||||||
if route, ok := h.domainRoutes[wildcard]; ok {
|
|
||||||
return route
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *RouteMatchHandler) processMatch(route *route.Route, domain string, ip netip.Addr) {
|
|
||||||
network := netip.PrefixFrom(ip, ip.BitLen())
|
network := netip.PrefixFrom(ip, ip.BitLen())
|
||||||
|
key := fmt.Sprintf("%s:%s", domain, network.String())
|
||||||
|
|
||||||
if h.currentPeerKey == "" {
|
if _, exists := h.interceptedIPs[key]; exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
|
if _, err := h.routeRefCounter.Increment(network, struct{}{}); err != nil {
|
||||||
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
log.Errorf("Failed to add route for IP %s: %v", network, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.currentPeerKey != "" {
|
||||||
|
if _, err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
|
||||||
|
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
||||||
|
// Rollback route addition
|
||||||
|
if _, err := h.routeRefCounter.Decrement(network); err != nil {
|
||||||
|
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.interceptedIPs[key] = network
|
||||||
|
log.Debugf("Added route for domain %s -> %s", domain, network)
|
||||||
}
|
}
|
||||||
|
@ -350,7 +350,8 @@ func validateDomains(domains []string) (domain.List, error) {
|
|||||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
domainRegex := regexp.MustCompile(``)
|
||||||
|
//domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||||
|
|
||||||
var domainList domain.List
|
var domainList domain.List
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user