mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 02:41:34 +01:00
Implement upstream DNS for intercepted domains (#3027)
This commit is contained in:
parent
619d899047
commit
da0a54c6d6
@ -10,12 +10,24 @@ import (
|
||||
|
||||
// MockServer is the mock instance of a dns server
|
||||
type MockServer struct {
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
RegisterHandlerFunc func([]string, dns.Handler) error
|
||||
DeregisterHandlerFunc func([]string) error
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler([]string, dns.Handler) error {
|
||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error {
|
||||
if m.RegisterHandlerFunc != nil {
|
||||
return m.RegisterHandlerFunc(domains, handler)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockServer) DeregisterHandler(domains []string) error {
|
||||
if m.DeregisterHandlerFunc != nil {
|
||||
return m.DeregisterHandlerFunc(domains)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -38,6 +38,7 @@ type Server interface {
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
SearchDomains() []string
|
||||
ProbeAvailability()
|
||||
UnregisterHandler(domains []string) error
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
@ -166,6 +167,20 @@ func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) UnregisterHandler(domains []string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
log.Debugf("unregistering handler for domains %s", domains)
|
||||
for _, domain := range domains {
|
||||
wosuff, _ := strings.CutPrefix(domain, "*.")
|
||||
pattern := dns.Fqdn(wosuff)
|
||||
s.service.DeregisterMux(pattern)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize instantiate host manager and the dns service
|
||||
func (s *DefaultServer) Initialize() (err error) {
|
||||
s.mux.Lock()
|
||||
|
@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
conn.wgProxyRelay = proxy
|
||||
}
|
||||
|
||||
// AllowedIP returns the allowed IP of the remote peer
|
||||
func (conn *Conn) AllowedIP() net.IP {
|
||||
return conn.allowedIP
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
@ -3,18 +3,27 @@ package dnsinterceptor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type DnsInterceptor struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
@ -23,8 +32,9 @@ type DnsInterceptor struct {
|
||||
statusRecorder *peer.Status
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
interceptedIPs map[string]netip.Prefix
|
||||
interceptedDomains domainMap
|
||||
peerConns map[string]*peer.Conn
|
||||
// TODO: peerConns add lock to sync with engine
|
||||
}
|
||||
|
||||
func New(
|
||||
@ -41,7 +51,7 @@ func New(
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dnsServer,
|
||||
interceptedIPs: make(map[string]netip.Prefix),
|
||||
interceptedDomains: make(domainMap),
|
||||
peerConns: peerConns,
|
||||
}
|
||||
}
|
||||
@ -62,125 +72,255 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// Remove all intercepted IPs
|
||||
for key, prefix := range d.interceptedIPs {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(d.interceptedIPs, key)
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
// TODO: remove from mux
|
||||
clear(d.interceptedDomains)
|
||||
|
||||
return nil
|
||||
if err := d.dnsServer.UnregisterHandler(d.route.Domains.ToPunycodeList()); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.currentPeerKey = peerKey
|
||||
|
||||
// Re-add all intercepted IPs for the new peer
|
||||
for _, prefix := range d.interceptedIPs {
|
||||
if _, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
d.currentPeerKey = peerKey
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if d.currentPeerKey != "" {
|
||||
for _, prefix := range d.interceptedIPs {
|
||||
var merr *multierror.Error
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.currentPeerKey = ""
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// ServeDNS implements the dns.Handler interface
|
||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
log.Debugf("received DNS request: %v", r)
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Debugf("received DNS request: %v", r.Question[0].Name)
|
||||
|
||||
if err := d.writeMsg(w, r); err != nil {
|
||||
if d.currentPeerKey == "" {
|
||||
// TODO: call normal upstream instead of returning an error?
|
||||
log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get upstream IP: %v", err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
if r == nil || len(r.Answer) == 0 {
|
||||
return w.WriteMsg(r)
|
||||
func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
peerConn, exists := d.peerConns[d.currentPeerKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
|
||||
}
|
||||
|
||||
for _, ans := range r.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := ans.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
d.processMatch(r.Question[0].Name, ip)
|
||||
}
|
||||
|
||||
return w.WriteMsg(r)
|
||||
return peerConn.AllowedIP(), nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
if r == nil {
|
||||
return fmt.Errorf("received nil DNS message")
|
||||
}
|
||||
|
||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||
// DNS names from miekg/dns are already in punycode format
|
||||
dom := domain.Domain(r.Question[0].Name)
|
||||
|
||||
var newPrefixes []netip.Prefix
|
||||
for _, ans := range r.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := ans.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Debugf("failed to convert A record IP: %v", rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip, ip.BitLen())
|
||||
newPrefixes = append(newPrefixes, prefix)
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(r); err != nil {
|
||||
return fmt.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
network := netip.PrefixFrom(ip, ip.BitLen())
|
||||
key := fmt.Sprintf("%s:%s", domain, network.String())
|
||||
oldPrefixes := d.interceptedDomains[domain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
if _, exists := d.interceptedIPs[key]; exists {
|
||||
return
|
||||
}
|
||||
var merr *multierror.Error
|
||||
|
||||
if _, err := d.routeRefCounter.Increment(network, struct{}{}); err != nil {
|
||||
log.Errorf("Failed to add route for IP %s: %v", network, err)
|
||||
return
|
||||
}
|
||||
// Add new prefixes
|
||||
for _, prefix := range toAdd {
|
||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Increment(network, d.currentPeerKey); err != nil {
|
||||
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
||||
// Rollback route addition
|
||||
if _, err := d.routeRefCounter.Decrement(network); err != nil {
|
||||
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
|
||||
}
|
||||
return
|
||||
if d.currentPeerKey == "" {
|
||||
continue
|
||||
}
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
d.interceptedIPs[key] = network
|
||||
log.Debugf("Added route for domain %s -> %s", domain, network)
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
for _, prefix := range toRemove {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update domain prefixes
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
d.interceptedDomains[domain] = newPrefixes
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
|
||||
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 {
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = false
|
||||
}
|
||||
for _, prefix := range newPrefixes {
|
||||
if _, exists := prefixSet[prefix]; exists {
|
||||
prefixSet[prefix] = true
|
||||
} else {
|
||||
toAdd = append(toAdd, prefix)
|
||||
}
|
||||
}
|
||||
for prefix, inUse := range prefixSet {
|
||||
if !inUse {
|
||||
toRemove = append(toRemove, prefix)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -346,7 +346,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, nil)
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerConns)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user