From 3251bc79faf77bdea2f0cd733877350d81629bf4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 13 Jun 2025 13:54:49 +0200 Subject: [PATCH] Preresolve domains --- .../routemanager/dnsinterceptor/handler.go | 156 +++++++++++++----- 1 file changed, 118 insertions(+), 38 deletions(-) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 78d5e3b30..dcbf63142 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -6,6 +6,7 @@ import ( "net/netip" "strings" "sync" + "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -63,6 +64,118 @@ func (d *DnsInterceptor) AddRoute(context.Context) error { return nil } +// preResolveDomains performs background DNS resolution for non-wildcard domains +func (d *DnsInterceptor) preResolveDomains() { + for _, domain := range d.route.Domains { + domainStr := string(domain) + + if strings.HasPrefix(domainStr, "*.") { + continue + } + + domainStr = strings.TrimSuffix(domainStr, ".") + go func(domain string) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := d.resolveAndUpdateDomain(ctx, domain); err != nil { + log.Debugf("pre-resolve failed for domain %s: %v", domain, err) + } else { + log.Tracef("pre-resolve completed for domain %s", domain) + } + }(domainStr) + } +} + +// resolveAndUpdateDomain performs DNS resolution and updates domain prefixes +func (d *DnsInterceptor) resolveAndUpdateDomain(ctx context.Context, qDomain string) error { + d.mu.RLock() + peerKey := d.currentPeerKey + d.mu.RUnlock() + + if peerKey == "" { + return fmt.Errorf("no current peer key") + } + + upstreamIP, err := d.getUpstreamIP(peerKey) + if err != nil { + return fmt.Errorf("get upstream IP: %v", err) + } + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qDomain), dns.TypeA) + msg.Id = dns.Id() + msg.MsgHdr.AuthenticatedData = true + + reply, err := d.exchangeWithUpstream(ctx, msg, upstreamIP) + if err != nil { + return fmt.Errorf("exchange with upstream: %v", err) + } + + if reply == nil || len(reply.Answer) == 0 { + return nil + } + + resolvedDomain := domain.Domain(dns.Fqdn(qDomain)) + return d.processResolveResponse(reply, resolvedDomain, resolvedDomain) +} + +// exchangeWithUpstream performs DNS exchange with the upstream server +func (d *DnsInterceptor) exchangeWithUpstream(ctx context.Context, msg *dns.Msg, upstreamIP netip.Addr) (*dns.Msg, error) { + client := &dns.Client{ + Timeout: nbdns.UpstreamTimeout, + Net: "udp", + } + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) + + reply, _, err := nbdns.ExchangeWithFallback(ctx, client, msg, upstream) + return reply, err +} + +// extractIPsFromDNSResponse extracts IP addresses from DNS answer records +func (d *DnsInterceptor) extractIPsFromDNSResponse(reply *dns.Msg, domainForLogging domain.Domain) []netip.Prefix { + if reply == nil || len(reply.Answer) == 0 { + return nil + } + + var prefixes []netip.Prefix + for _, answer := range reply.Answer { + var ip netip.Addr + switch rr := answer.(type) { + case *dns.A: + addr, ok := netip.AddrFromSlice(rr.A) + if !ok { + log.Tracef("failed to convert A record for domain=%s ip=%v", domainForLogging, rr.A) + continue + } + ip = addr + case *dns.AAAA: + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + log.Tracef("failed to convert AAAA record for domain=%s ip=%v", domainForLogging, rr.AAAA) + continue + } + ip = addr + default: + continue + } + + prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen()) + prefixes = append(prefixes, prefix) + } + + return prefixes +} + +// processResolveResponse extracts IPs from DNS response and updates domain prefixes +func (d *DnsInterceptor) processResolveResponse(reply *dns.Msg, resolvedDomain, originalDomain domain.Domain) error { + newPrefixes := d.extractIPsFromDNSResponse(reply, resolvedDomain) + if len(newPrefixes) > 0 { + return d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes) + } + return nil +} + func (d *DnsInterceptor) RemoveRoute() error { d.mu.Lock() @@ -113,6 +226,7 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { } d.currentPeerKey = peerKey + go d.preResolveDomains() return nberrors.FormatErrorOrNil(merr) } @@ -165,12 +279,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } - client := &dns.Client{ - Timeout: nbdns.UpstreamTimeout, - Net: "udp", - } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) + + reply, err := d.exchangeWithUpstream(context.TODO(), r, upstreamIP) if err != nil { log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { @@ -235,43 +345,13 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { } resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name)) - - // already punycode via RegisterHandler() originalDomain := domain.Domain(origPattern) if originalDomain == "" { originalDomain = resolvedDomain } - var newPrefixes []netip.Prefix - for _, answer := range r.Answer { - var ip netip.Addr - switch rr := answer.(type) { - case *dns.A: - addr, ok := netip.AddrFromSlice(rr.A) - if !ok { - log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) - continue - } - ip = addr - case *dns.AAAA: - addr, ok := netip.AddrFromSlice(rr.AAAA) - if !ok { - log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) - continue - } - ip = addr - default: - continue - } - - prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen()) - newPrefixes = append(newPrefixes, prefix) - } - - if len(newPrefixes) > 0 { - if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { - log.Errorf("failed to update domain prefixes: %v", err) - } + if err := d.processResolveResponse(r, resolvedDomain, originalDomain); err != nil { + log.Errorf("failed to process DNS response: %v", err) } }