diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 87de168a5..556438970 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -10,12 +10,26 @@ import ( // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func([]string, dns.Handler) error + UnregisterHandlerFunc func([]string) error } -func (m *MockServer) RegisterHandler([]string, dns.Handler) error { +// RegisterHandler calls RegisterHandlerFunc when it is set and returns nil otherwise +func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error { + if m.RegisterHandlerFunc != nil { + return m.RegisterHandlerFunc(domains, handler) + } + return nil +} + +// UnregisterHandler calls UnregisterHandlerFunc when it is set and returns nil otherwise; it is named after Server.UnregisterHandler so the mock keeps satisfying the Server interface +func (m *MockServer) UnregisterHandler(domains []string) error { + if m.UnregisterHandlerFunc != nil { + return m.UnregisterHandlerFunc(domains) + } return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f1a9d3155..ff9ebc00a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -38,6 +38,7 @@ type Server interface { OnUpdatedHostDNSServer(strings []string) SearchDomains() []string ProbeAvailability() + UnregisterHandler(domains []string) error } type registeredHandlerMap map[string]handlerWithStop @@ -166,6 +167,21 @@ func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) e return nil } +// UnregisterHandler deregisters the mux pattern of every given domain from the DNS service; a leading "*." is stripped and the name is converted to an FQDN first +func (s *DefaultServer) UnregisterHandler(domains []string) error { + s.mux.Lock() + defer s.mux.Unlock() + + log.Debugf("unregistering handler for domains %s", domains) + for _, domain := range domains { + wosuff, _ := strings.CutPrefix(domain, "*.") + pattern := dns.Fqdn(wosuff) + s.service.DeregisterMux(pattern) + } + + return nil +} + // Initialize instantiate host manager and the dns service func (s 
*DefaultServer) Initialize() (err error) { s.mux.Lock() diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 5c2e2cb60..b8cb2582f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { conn.wgProxyRelay = proxy } +// AllowedIP returns the remote peer's allowed IP as stored in conn.allowedIP +func (conn *Conn) AllowedIP() net.IP { + return conn.allowedIP +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 4cfa8f042..7e67dbc68 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -3,18 +3,27 @@ package dnsinterceptor import ( "context" "fmt" + "net" "net/netip" + "strings" "sync" + "time" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) +type domainMap map[domain.Domain][]netip.Prefix + type DnsInterceptor struct { mu sync.RWMutex route *route.Route @@ -23,8 +32,9 @@ type DnsInterceptor struct { statusRecorder *peer.Status dnsServer nbdns.Server currentPeerKey string - interceptedIPs map[string]netip.Prefix + interceptedDomains domainMap peerConns map[string]*peer.Conn + // TODO: guard peerConns with a lock so access is synchronized with the engine } func New( @@ -41,7 +51,7 @@ func New( allowedIPsRefcounter: allowedIPsRefCounter, statusRecorder: statusRecorder, dnsServer: dnsServer, - interceptedIPs: make(map[string]netip.Prefix), + 
interceptedDomains: make(domainMap), peerConns: peerConns, } } @@ -62,125 +72,265 @@ func (d *DnsInterceptor) RemoveRoute() error { d.mu.Lock() defer d.mu.Unlock() - // Remove all intercepted IPs - for key, prefix := range d.interceptedIPs { - if _, err := d.routeRefCounter.Decrement(prefix); err != nil { - log.Errorf("Failed to remove route for IP %s: %v", prefix, err) - } - if d.currentPeerKey != "" { - if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { - log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } } } - delete(d.interceptedIPs, key) + log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) + + d.statusRecorder.DeleteResolvedDomainsStates(domain) } - // TODO: remove from mux + clear(d.interceptedDomains) - return nil + if err := d.dnsServer.UnregisterHandler(d.route.Domains.ToPunycodeList()); err != nil { + merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err)) + } + + return nberrors.FormatErrorOrNil(merr) } func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { d.mu.Lock() defer d.mu.Unlock() - d.currentPeerKey = peerKey - - // Re-add all intercepted IPs for the new peer - for _, prefix := range d.interceptedIPs { - if _, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { - log.Errorf("Failed to add allowed IP %s: %v", prefix, err) + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := 
range prefixes { + if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } } } - return nil + d.currentPeerKey = peerKey + return nberrors.FormatErrorOrNil(merr) } func (d *DnsInterceptor) RemoveAllowedIPs() error { d.mu.Lock() defer d.mu.Unlock() - if d.currentPeerKey != "" { - for _, prefix := range d.interceptedIPs { + var merr *multierror.Error + for _, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { - log.Errorf("Failed to remove allowed IP %s: %v", prefix, err) + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) } } } d.currentPeerKey = "" - return nil + return nberrors.FormatErrorOrNil(merr) } // ServeDNS implements the dns.Handler interface func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Debugf("received DNS request: %v", r) if len(r.Question) == 0 { return } + log.Debugf("received DNS request: %v", r.Question[0].Name) - if err := d.writeMsg(w, r); err != nil { + if d.currentPeerKey == "" { + // TODO: call normal upstream instead of returning an error? 
+ log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name) + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } + return + } + + upstreamIP, err := d.getUpstreamIP() + if err != nil { + log.Errorf("failed to get upstream IP: %v", err) + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } + return + } + + client := &dns.Client{ + Timeout: 5 * time.Second, + Net: "udp", + } + upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) + reply, _, err := client.ExchangeContext(context.Background(), r, upstream) + if err != nil { + log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } + return + } + + // Log only after the error check: reply is nil when the exchange fails, so reading reply.Answer earlier would panic + log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer) + + reply.Id = r.Id + if err := d.writeMsg(w, reply); err != nil { log.Errorf("failed writing DNS response: %v", err) } } -func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { - if r == nil || len(r.Answer) == 0 { - return w.WriteMsg(r) +// getUpstreamIP returns the allowed IP of the connection stored for the current peer key, or an error when peerConns has no entry for that key +func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) { + d.mu.RLock() + defer d.mu.RUnlock() + + peerConn, exists := d.peerConns[d.currentPeerKey] + if !exists { + return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey) } - - for _, ans := range r.Answer { - var ip netip.Addr - switch rr := ans.(type) { - case *dns.A: - addr, ok := netip.AddrFromSlice(rr.A) - if !ok { - continue - } - ip = addr - case *dns.AAAA: - addr, ok := netip.AddrFromSlice(rr.AAAA) - if !ok { - continue - } - ip = addr - default: - continue - } - - 
d.processMatch(r.Question[0].Name, ip) - } - - return w.WriteMsg(r) + return peerConn.AllowedIP(), nil } -func (d *DnsInterceptor) processMatch(domain string, ip netip.Addr) { +// writeMsg records the A/AAAA addresses of the reply as single-address prefixes for the queried domain and then writes the reply to w; a nil message is rejected with an error +func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { + if r == nil { + return fmt.Errorf("received nil DNS message") + } + + if len(r.Answer) > 0 && len(r.Question) > 0 { + // DNS names from miekg/dns are already in punycode format + dom := domain.Domain(r.Question[0].Name) + + var newPrefixes []netip.Prefix + for _, ans := range r.Answer { + var ip netip.Addr + switch rr := ans.(type) { + case *dns.A: + addr, ok := netip.AddrFromSlice(rr.A) + if !ok { + log.Debugf("failed to convert A record IP: %v", rr.A) + continue + } + ip = addr + case *dns.AAAA: + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA) + continue + } + ip = addr + default: + continue + } + + prefix := netip.PrefixFrom(ip, ip.BitLen()) + newPrefixes = append(newPrefixes, prefix) + } + + if len(newPrefixes) > 0 { + if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil { + log.Errorf("failed to update domain prefixes: %v", err) + } + } + } + + if err := w.WriteMsg(r); err != nil { + return fmt.Errorf("failed to write DNS response: %v", err) + } + + return nil +} + +// updateDomainPrefixes takes route and allowed-IP references for newly resolved prefixes, releases prefixes that are no longer resolved unless the route has KeepRoute set, and stores the tracked set for the domain +func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error { d.mu.Lock() defer d.mu.Unlock() - network := netip.PrefixFrom(ip, ip.BitLen()) - key := fmt.Sprintf("%s:%s", domain, network.String()) + oldPrefixes := d.interceptedDomains[domain] + toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) - if _, exists := d.interceptedIPs[key]; exists { - return - } + var merr *multierror.Error - if _, err := d.routeRefCounter.Increment(network, struct{}{}); err != nil { - log.Errorf("Failed to add route for IP %s: %v", network, err) - return - } + // Add new prefixes + for _, prefix := range toAdd { + if _, err := 
d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) + continue + } - if d.currentPeerKey != "" { - if _, err := d.allowedIPsRefcounter.Increment(network, d.currentPeerKey); err != nil { - log.Errorf("Failed to add allowed IP %s: %v", network, err) - // Rollback route addition - if _, err := d.routeRefCounter.Decrement(network); err != nil { - log.Errorf("Failed to rollback route addition for IP %s: %v", network, err) - } - return + if d.currentPeerKey == "" { + continue + } + if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != d.currentPeerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) } } - d.interceptedIPs[key] = network - log.Debugf("Added route for domain %s -> %s", domain, network) + if !d.route.KeepRoute { + // Remove old prefixes + for _, prefix := range toRemove { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + } + + // Update domain prefixes + if len(toAdd) > 0 || len(toRemove) > 0 { + if d.route.KeepRoute { + // Stale prefixes were left installed above, so keep tracking them; otherwise RemoveRoute could never release their references + newPrefixes = append(oldPrefixes[:len(oldPrefixes):len(oldPrefixes)], toAdd...) + toRemove = nil + } + d.interceptedDomains[domain] = newPrefixes + d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes) + + if len(toAdd) > 0 { + log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd) + } + if len(toRemove) > 0 { + log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +// determinePrefixChanges returns the prefixes found only in newPrefixes (toAdd) and those found only in oldPrefixes (toRemove) +func 
determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { + prefixSet := make(map[netip.Prefix]bool) + for _, prefix := range oldPrefixes { + prefixSet[prefix] = false + } + for _, prefix := range newPrefixes { + if _, exists := prefixSet[prefix]; exists { + prefixSet[prefix] = true + } else { + toAdd = append(toAdd, prefix) + } + } + for prefix, inUse := range prefixSet { + if !inUse { + toRemove = append(toRemove, prefix) + } + } + return } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 3afdd509e..ab2ce0361 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -346,7 +346,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, nil) + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerConns) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() }