Implement upstream DNS for intercepted domains (#3027)

This commit is contained in:
Viktor Liu 2024-12-11 17:57:30 +01:00 committed by GitHub
parent 619d899047
commit da0a54c6d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 250 additions and 78 deletions

View File

@ -10,12 +10,24 @@ import (
// MockServer is the mock instance of a dns server
type MockServer struct {
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler) error
DeregisterHandlerFunc func([]string) error
}
func (m *MockServer) RegisterHandler([]string, dns.Handler) error {
// RegisterHandler forwards the call to RegisterHandlerFunc when one has been
// configured on the mock; without it the call is a no-op that reports success.
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error {
	if m.RegisterHandlerFunc == nil {
		return nil
	}
	return m.RegisterHandlerFunc(domains, handler)
}
// DeregisterHandler forwards the call to DeregisterHandlerFunc when one has
// been configured on the mock; without it the call is a no-op that returns nil.
func (m *MockServer) DeregisterHandler(domains []string) error {
	if m.DeregisterHandlerFunc != nil {
		return m.DeregisterHandlerFunc(domains)
	}
	return nil
}

// UnregisterHandler is the method name the Server interface declares for
// removing handlers. It delegates to DeregisterHandler so the mock keeps
// satisfying that interface while existing DeregisterHandler callers still work.
func (m *MockServer) UnregisterHandler(domains []string) error {
	return m.DeregisterHandler(domains)
}

View File

@ -38,6 +38,7 @@ type Server interface {
OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string
ProbeAvailability()
UnregisterHandler(domains []string) error
}
// registeredHandlerMap maps a string key to a handlerWithStop.
// NOTE(review): the key is presumably the DNS pattern the handler was registered for — confirm.
type registeredHandlerMap map[string]handlerWithStop
@ -166,6 +167,20 @@ func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) e
return nil
}
// UnregisterHandler removes the mux registration for every given domain.
//
// The server mutex is held for the whole call. For each domain a leading "*."
// wildcard label is stripped and the remainder is turned into a fully
// qualified name before it is handed to the service's DeregisterMux.
// The method always returns nil.
func (s *DefaultServer) UnregisterHandler(domains []string) error {
	s.mux.Lock()
	defer s.mux.Unlock()

	log.Debugf("unregistering handler for domains %s", domains)

	for i := range domains {
		base := strings.TrimPrefix(domains[i], "*.")
		s.service.DeregisterMux(dns.Fqdn(base))
	}

	return nil
}
// Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock()

View File

@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
conn.wgProxyRelay = proxy
}
// AllowedIP returns the allowed IP of the remote peer.
// The returned net.IP is the Conn's own allowedIP field, not a copy, so
// callers must not modify its bytes.
func (conn *Conn) AllowedIP() net.IP {
return conn.allowedIP
}
// isController reports whether config.LocalKey compares greater than config.Key.
// NOTE(review): presumably this decides which side of a peer pair drives the
// connection — confirm against the callers.
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@ -3,18 +3,27 @@ package dnsinterceptor
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
type domainMap map[domain.Domain][]netip.Prefix
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
@ -23,8 +32,9 @@ type DnsInterceptor struct {
statusRecorder *peer.Status
dnsServer nbdns.Server
currentPeerKey string
interceptedIPs map[string]netip.Prefix
interceptedDomains domainMap
peerConns map[string]*peer.Conn
// TODO: add a lock around peerConns to synchronize access with the engine
}
func New(
@ -41,7 +51,7 @@ func New(
allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder,
dnsServer: dnsServer,
interceptedIPs: make(map[string]netip.Prefix),
interceptedDomains: make(domainMap),
peerConns: peerConns,
}
}
@ -62,125 +72,255 @@ func (d *DnsInterceptor) RemoveRoute() error {
d.mu.Lock()
defer d.mu.Unlock()
// Remove all intercepted IPs
for key, prefix := range d.interceptedIPs {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
delete(d.interceptedIPs, key)
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
d.statusRecorder.DeleteResolvedDomainsStates(domain)
}
// TODO: remove from mux
clear(d.interceptedDomains)
return nil
if err := d.dnsServer.UnregisterHandler(d.route.Domains.ToPunycodeList()); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock()
defer d.mu.Unlock()
d.currentPeerKey = peerKey
// Re-add all intercepted IPs for the new peer
for _, prefix := range d.interceptedIPs {
if _, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
}
return nil
d.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (d *DnsInterceptor) RemoveAllowedIPs() error {
d.mu.Lock()
defer d.mu.Unlock()
if d.currentPeerKey != "" {
for _, prefix := range d.interceptedIPs {
var merr *multierror.Error
for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
d.currentPeerKey = ""
return nil
return nberrors.FormatErrorOrNil(merr)
}
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Debugf("received DNS request: %v", r)
if len(r.Question) == 0 {
return
}
log.Debugf("received DNS request: %v", r.Question[0].Name)
if err := d.writeMsg(w, r); err != nil {
if d.currentPeerKey == "" {
// TODO: call normal upstream instead of returning an error?
log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
upstreamIP, err := d.getUpstreamIP()
if err != nil {
log.Errorf("failed to get upstream IP: %v", err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
client := &dns.Client{
Timeout: 5 * time.Second,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if r == nil || len(r.Answer) == 0 {
return w.WriteMsg(r)
func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
d.mu.RLock()
defer d.mu.RUnlock()
peerConn, exists := d.peerConns[d.currentPeerKey]
if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
}
for _, ans := range r.Answer {
var ip netip.Addr
switch rr := ans.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
ip = addr
default:
continue
}
d.processMatch(r.Question[0].Name, ip)
}
return w.WriteMsg(r)
return peerConn.AllowedIP(), nil
}
func (d *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
// writeMsg records the addresses carried by a DNS reply and then writes the
// reply to w.
//
// When the message has both a question and answers, every A and AAAA answer is
// converted into a single-address prefix and the resulting list is passed to
// updateDomainPrefixes under the name of the first question. Other record
// types are ignored. A failure to update the prefixes is only logged; the
// reply is still written.
//
// It returns an error if r is nil or if writing to w fails.
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
	if r == nil {
		return fmt.Errorf("received nil DNS message")
	}

	if len(r.Answer) > 0 && len(r.Question) > 0 {
		// DNS names from miekg/dns are already in punycode format
		dom := domain.Domain(r.Question[0].Name)

		var newPrefixes []netip.Prefix
		for _, ans := range r.Answer {
			var ip netip.Addr
			switch rr := ans.(type) {
			case *dns.A:
				addr, ok := netip.AddrFromSlice(rr.A)
				if !ok {
					log.Debugf("failed to convert A record IP: %v", rr.A)
					continue
				}
				// A net.IP holding an IPv4 address may be 16 bytes long, which
				// AddrFromSlice turns into an IPv4-mapped IPv6 address. Unmap
				// it so the prefix below is a /32 rather than a /128.
				ip = addr.Unmap()
			case *dns.AAAA:
				addr, ok := netip.AddrFromSlice(rr.AAAA)
				if !ok {
					log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
					continue
				}
				ip = addr
			default:
				continue
			}

			newPrefixes = append(newPrefixes, netip.PrefixFrom(ip, ip.BitLen()))
		}

		if len(newPrefixes) > 0 {
			if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
				log.Errorf("failed to update domain prefixes: %v", err)
			}
		}
	}

	if err := w.WriteMsg(r); err != nil {
		return fmt.Errorf("failed to write DNS response: %w", err)
	}

	return nil
}
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
network := netip.PrefixFrom(ip, ip.BitLen())
key := fmt.Sprintf("%s:%s", domain, network.String())
oldPrefixes := d.interceptedDomains[domain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
if _, exists := d.interceptedIPs[key]; exists {
return
}
var merr *multierror.Error
if _, err := d.routeRefCounter.Increment(network, struct{}{}); err != nil {
log.Errorf("Failed to add route for IP %s: %v", network, err)
return
}
// Add new prefixes
for _, prefix := range toAdd {
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
continue
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Increment(network, d.currentPeerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", network, err)
// Rollback route addition
if _, err := d.routeRefCounter.Decrement(network); err != nil {
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
}
return
if d.currentPeerKey == "" {
continue
}
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
d.interceptedIPs[key] = network
log.Debugf("Added route for domain %s -> %s", domain, network)
if !d.route.KeepRoute {
// Remove old prefixes
for _, prefix := range toRemove {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
}
// Update domain prefixes
if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[domain] = newPrefixes
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
}
if len(toRemove) > 0 {
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
}
}
return nberrors.FormatErrorOrNil(merr)
}
// determinePrefixChanges compares the previously known prefixes with the newly
// resolved ones.
//
// toAdd lists, in input order, every entry of newPrefixes that is absent from
// oldPrefixes (a prefix repeated in newPrefixes is listed once per occurrence).
// toRemove lists each distinct prefix of oldPrefixes that does not appear in
// newPrefixes; its order is unspecified because it comes from map iteration.
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
	// stillResolved maps every old prefix to whether the new set contains it.
	stillResolved := make(map[netip.Prefix]bool)
	for i := range oldPrefixes {
		stillResolved[oldPrefixes[i]] = false
	}

	for i := range newPrefixes {
		candidate := newPrefixes[i]
		if _, known := stillResolved[candidate]; !known {
			toAdd = append(toAdd, candidate)
			continue
		}
		stillResolved[candidate] = true
	}

	for candidate, kept := range stillResolved {
		if kept {
			continue
		}
		toRemove = append(toRemove, candidate)
	}

	return toAdd, toRemove
}

View File

@ -346,7 +346,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, nil)
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerConns)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}