Implement upstream DNS for intercepted domains (#3027)

This commit is contained in:
Viktor Liu 2024-12-11 17:57:30 +01:00 committed by GitHub
parent 619d899047
commit da0a54c6d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 250 additions and 78 deletions

View File

@ -13,9 +13,21 @@ type MockServer struct {
InitializeFunc func() error InitializeFunc func() error
StopFunc func() StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler) error
DeregisterHandlerFunc func([]string) error
} }
func (m *MockServer) RegisterHandler([]string, dns.Handler) error { func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error {
if m.RegisterHandlerFunc != nil {
return m.RegisterHandlerFunc(domains, handler)
}
return nil
}
func (m *MockServer) DeregisterHandler(domains []string) error {
if m.DeregisterHandlerFunc != nil {
return m.DeregisterHandlerFunc(domains)
}
return nil return nil
} }

View File

@ -38,6 +38,7 @@ type Server interface {
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
UnregisterHandler(domains []string) error
} }
type registeredHandlerMap map[string]handlerWithStop type registeredHandlerMap map[string]handlerWithStop
@ -166,6 +167,20 @@ func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) e
return nil return nil
} }
func (s *DefaultServer) UnregisterHandler(domains []string) error {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("unregistering handler for domains %s", domains)
for _, domain := range domains {
wosuff, _ := strings.CutPrefix(domain, "*.")
pattern := dns.Fqdn(wosuff)
s.service.DeregisterMux(pattern)
}
return nil
}
// Initialize instantiate host manager and the dns service // Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) { func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock() s.mux.Lock()

View File

@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
conn.wgProxyRelay = proxy conn.wgProxyRelay = proxy
} }
// AllowedIP returns the allowed IP of the remote peer
func (conn *Conn) AllowedIP() net.IP {
return conn.allowedIP
}
func isController(config ConnConfig) bool { func isController(config ConnConfig) bool {
return config.LocalKey > config.Key return config.LocalKey > config.Key
} }

View File

@ -3,18 +3,27 @@ package dnsinterceptor
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strings"
"sync" "sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type domainMap map[domain.Domain][]netip.Prefix
type DnsInterceptor struct { type DnsInterceptor struct {
mu sync.RWMutex mu sync.RWMutex
route *route.Route route *route.Route
@ -23,8 +32,9 @@ type DnsInterceptor struct {
statusRecorder *peer.Status statusRecorder *peer.Status
dnsServer nbdns.Server dnsServer nbdns.Server
currentPeerKey string currentPeerKey string
interceptedIPs map[string]netip.Prefix interceptedDomains domainMap
peerConns map[string]*peer.Conn peerConns map[string]*peer.Conn
// TODO: peerConns add lock to sync with engine
} }
func New( func New(
@ -41,7 +51,7 @@ func New(
allowedIPsRefcounter: allowedIPsRefCounter, allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
dnsServer: dnsServer, dnsServer: dnsServer,
interceptedIPs: make(map[string]netip.Prefix), interceptedDomains: make(domainMap),
peerConns: peerConns, peerConns: peerConns,
} }
} }
@ -62,85 +72,154 @@ func (d *DnsInterceptor) RemoveRoute() error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
// Remove all intercepted IPs var merr *multierror.Error
for key, prefix := range d.interceptedIPs { for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil { if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove route for IP %s: %v", prefix, err) merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
} }
if d.currentPeerKey != "" { if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
} }
} }
delete(d.interceptedIPs, key) }
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
d.statusRecorder.DeleteResolvedDomainsStates(domain)
} }
// TODO: remove from mux clear(d.interceptedDomains)
return nil if err := d.dnsServer.UnregisterHandler(d.route.Domains.ToPunycodeList()); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
} }
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
}
d.currentPeerKey = peerKey d.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
// Re-add all intercepted IPs for the new peer
for _, prefix := range d.interceptedIPs {
if _, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
}
}
return nil
} }
func (d *DnsInterceptor) RemoveAllowedIPs() error { func (d *DnsInterceptor) RemoveAllowedIPs() error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if d.currentPeerKey != "" { var merr *multierror.Error
for _, prefix := range d.interceptedIPs { for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
} }
} }
} }
d.currentPeerKey = "" d.currentPeerKey = ""
return nil return nberrors.FormatErrorOrNil(merr)
} }
// ServeDNS implements the dns.Handler interface // ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Debugf("received DNS request: %v", r)
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
log.Debugf("received DNS request: %v", r.Question[0].Name)
if err := d.writeMsg(w, r); err != nil { if d.currentPeerKey == "" {
// TODO: call normal upstream instead of returning an error?
log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
upstreamIP, err := d.getUpstreamIP()
if err != nil {
log.Errorf("failed to get upstream IP: %v", err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
client := &dns.Client{
Timeout: 5 * time.Second,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err) log.Errorf("failed writing DNS response: %v", err)
} }
} }
func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
d.mu.RLock()
defer d.mu.RUnlock()
peerConn, exists := d.peerConns[d.currentPeerKey]
if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
}
return peerConn.AllowedIP(), nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if r == nil || len(r.Answer) == 0 { if r == nil {
return w.WriteMsg(r) return fmt.Errorf("received nil DNS message")
} }
if len(r.Answer) > 0 && len(r.Question) > 0 {
// DNS names from miekg/dns are already in punycode format
dom := domain.Domain(r.Question[0].Name)
var newPrefixes []netip.Prefix
for _, ans := range r.Answer { for _, ans := range r.Answer {
var ip netip.Addr var ip netip.Addr
switch rr := ans.(type) { switch rr := ans.(type) {
case *dns.A: case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A) addr, ok := netip.AddrFromSlice(rr.A)
if !ok { if !ok {
log.Debugf("failed to convert A record IP: %v", rr.A)
continue continue
} }
ip = addr ip = addr
case *dns.AAAA: case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA) addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok { if !ok {
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
continue continue
} }
ip = addr ip = addr
@ -148,39 +227,100 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
continue continue
} }
d.processMatch(r.Question[0].Name, ip) prefix := netip.PrefixFrom(ip, ip.BitLen())
newPrefixes = append(newPrefixes, prefix)
} }
return w.WriteMsg(r) if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
}
}
if err := w.WriteMsg(r); err != nil {
return fmt.Errorf("failed to write DNS response: %v", err)
}
return nil
} }
func (d *DnsInterceptor) processMatch(domain string, ip netip.Addr) { func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
network := netip.PrefixFrom(ip, ip.BitLen()) oldPrefixes := d.interceptedDomains[domain]
key := fmt.Sprintf("%s:%s", domain, network.String()) toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
if _, exists := d.interceptedIPs[key]; exists { var merr *multierror.Error
return
// Add new prefixes
for _, prefix := range toAdd {
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
continue
} }
if _, err := d.routeRefCounter.Increment(network, struct{}{}); err != nil { if d.currentPeerKey == "" {
log.Errorf("Failed to add route for IP %s: %v", network, err) continue
return }
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
} }
if !d.route.KeepRoute {
// Remove old prefixes
for _, prefix := range toRemove {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
}
if d.currentPeerKey != "" { if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Increment(network, d.currentPeerKey); err != nil { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", network, err) merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
// Rollback route addition }
if _, err := d.routeRefCounter.Decrement(network); err != nil {
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
} }
return
} }
} }
d.interceptedIPs[key] = network // Update domain prefixes
log.Debugf("Added route for domain %s -> %s", domain, network) if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[domain] = newPrefixes
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
}
if len(toRemove) > 0 {
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
} }

View File

@ -346,7 +346,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, nil) clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerConns)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }