mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 02:08:40 +02:00
Preresolve domains
This commit is contained in:
parent
bdf2994e97
commit
3251bc79fa
@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
@ -63,6 +64,118 @@ func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// preResolveDomains performs background DNS resolution for non-wildcard domains
|
||||
func (d *DnsInterceptor) preResolveDomains() {
|
||||
for _, domain := range d.route.Domains {
|
||||
domainStr := string(domain)
|
||||
|
||||
if strings.HasPrefix(domainStr, "*.") {
|
||||
continue
|
||||
}
|
||||
|
||||
domainStr = strings.TrimSuffix(domainStr, ".")
|
||||
go func(domain string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := d.resolveAndUpdateDomain(ctx, domain); err != nil {
|
||||
log.Debugf("pre-resolve failed for domain %s: %v", domain, err)
|
||||
} else {
|
||||
log.Tracef("pre-resolve completed for domain %s", domain)
|
||||
}
|
||||
}(domainStr)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAndUpdateDomain performs DNS resolution and updates domain prefixes
|
||||
func (d *DnsInterceptor) resolveAndUpdateDomain(ctx context.Context, qDomain string) error {
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
return fmt.Errorf("no current peer key")
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get upstream IP: %v", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(qDomain), dns.TypeA)
|
||||
msg.Id = dns.Id()
|
||||
msg.MsgHdr.AuthenticatedData = true
|
||||
|
||||
reply, err := d.exchangeWithUpstream(ctx, msg, upstreamIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exchange with upstream: %v", err)
|
||||
}
|
||||
|
||||
if reply == nil || len(reply.Answer) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
resolvedDomain := domain.Domain(dns.Fqdn(qDomain))
|
||||
return d.processResolveResponse(reply, resolvedDomain, resolvedDomain)
|
||||
}
|
||||
|
||||
// exchangeWithUpstream performs DNS exchange with the upstream server
|
||||
func (d *DnsInterceptor) exchangeWithUpstream(ctx context.Context, msg *dns.Msg, upstreamIP netip.Addr) (*dns.Msg, error) {
|
||||
client := &dns.Client{
|
||||
Timeout: nbdns.UpstreamTimeout,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
|
||||
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, msg, upstream)
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// extractIPsFromDNSResponse extracts IP addresses from DNS answer records
|
||||
func (d *DnsInterceptor) extractIPsFromDNSResponse(reply *dns.Msg, domainForLogging domain.Domain) []netip.Prefix {
|
||||
if reply == nil || len(reply.Answer) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var prefixes []netip.Prefix
|
||||
for _, answer := range reply.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", domainForLogging, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", domainForLogging, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||
prefixes = append(prefixes, prefix)
|
||||
}
|
||||
|
||||
return prefixes
|
||||
}
|
||||
|
||||
// processResolveResponse extracts IPs from DNS response and updates domain prefixes
|
||||
func (d *DnsInterceptor) processResolveResponse(reply *dns.Msg, resolvedDomain, originalDomain domain.Domain) error {
|
||||
newPrefixes := d.extractIPsFromDNSResponse(reply, resolvedDomain)
|
||||
if len(newPrefixes) > 0 {
|
||||
return d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveRoute() error {
|
||||
d.mu.Lock()
|
||||
|
||||
@ -113,6 +226,7 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
}
|
||||
|
||||
d.currentPeerKey = peerKey
|
||||
go d.preResolveDomains()
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@ -165,12 +279,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if r.Extra == nil {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
client := &dns.Client{
|
||||
Timeout: nbdns.UpstreamTimeout,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
||||
|
||||
reply, err := d.exchangeWithUpstream(context.TODO(), r, upstreamIP)
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
@ -235,43 +345,13 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
}
|
||||
|
||||
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
|
||||
|
||||
// already punycode via RegisterHandler()
|
||||
originalDomain := domain.Domain(origPattern)
|
||||
if originalDomain == "" {
|
||||
originalDomain = resolvedDomain
|
||||
}
|
||||
|
||||
var newPrefixes []netip.Prefix
|
||||
for _, answer := range r.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||
newPrefixes = append(newPrefixes, prefix)
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
if err := d.processResolveResponse(r, resolvedDomain, originalDomain); err != nil {
|
||||
log.Errorf("failed to process DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user