mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 02:08:40 +02:00
Preresolve domains
This commit is contained in:
parent
bdf2994e97
commit
3251bc79fa
@ -6,6 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -63,6 +64,118 @@ func (d *DnsInterceptor) AddRoute(context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// preResolveDomains performs background DNS resolution for non-wildcard domains
|
||||||
|
func (d *DnsInterceptor) preResolveDomains() {
|
||||||
|
for _, domain := range d.route.Domains {
|
||||||
|
domainStr := string(domain)
|
||||||
|
|
||||||
|
if strings.HasPrefix(domainStr, "*.") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
domainStr = strings.TrimSuffix(domainStr, ".")
|
||||||
|
go func(domain string) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := d.resolveAndUpdateDomain(ctx, domain); err != nil {
|
||||||
|
log.Debugf("pre-resolve failed for domain %s: %v", domain, err)
|
||||||
|
} else {
|
||||||
|
log.Tracef("pre-resolve completed for domain %s", domain)
|
||||||
|
}
|
||||||
|
}(domainStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveAndUpdateDomain performs DNS resolution and updates domain prefixes
|
||||||
|
func (d *DnsInterceptor) resolveAndUpdateDomain(ctx context.Context, qDomain string) error {
|
||||||
|
d.mu.RLock()
|
||||||
|
peerKey := d.currentPeerKey
|
||||||
|
d.mu.RUnlock()
|
||||||
|
|
||||||
|
if peerKey == "" {
|
||||||
|
return fmt.Errorf("no current peer key")
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get upstream IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.SetQuestion(dns.Fqdn(qDomain), dns.TypeA)
|
||||||
|
msg.Id = dns.Id()
|
||||||
|
msg.MsgHdr.AuthenticatedData = true
|
||||||
|
|
||||||
|
reply, err := d.exchangeWithUpstream(ctx, msg, upstreamIP)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("exchange with upstream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reply == nil || len(reply.Answer) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedDomain := domain.Domain(dns.Fqdn(qDomain))
|
||||||
|
return d.processResolveResponse(reply, resolvedDomain, resolvedDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeWithUpstream performs DNS exchange with the upstream server
|
||||||
|
func (d *DnsInterceptor) exchangeWithUpstream(ctx context.Context, msg *dns.Msg, upstreamIP netip.Addr) (*dns.Msg, error) {
|
||||||
|
client := &dns.Client{
|
||||||
|
Timeout: nbdns.UpstreamTimeout,
|
||||||
|
Net: "udp",
|
||||||
|
}
|
||||||
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||||
|
|
||||||
|
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, msg, upstream)
|
||||||
|
return reply, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractIPsFromDNSResponse extracts IP addresses from DNS answer records
|
||||||
|
func (d *DnsInterceptor) extractIPsFromDNSResponse(reply *dns.Msg, domainForLogging domain.Domain) []netip.Prefix {
|
||||||
|
if reply == nil || len(reply.Answer) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefixes []netip.Prefix
|
||||||
|
for _, answer := range reply.Answer {
|
||||||
|
var ip netip.Addr
|
||||||
|
switch rr := answer.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.A)
|
||||||
|
if !ok {
|
||||||
|
log.Tracef("failed to convert A record for domain=%s ip=%v", domainForLogging, rr.A)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
case *dns.AAAA:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
|
if !ok {
|
||||||
|
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", domainForLogging, rr.AAAA)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||||
|
prefixes = append(prefixes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefixes
|
||||||
|
}
|
||||||
|
|
||||||
|
// processResolveResponse extracts IPs from DNS response and updates domain prefixes
|
||||||
|
func (d *DnsInterceptor) processResolveResponse(reply *dns.Msg, resolvedDomain, originalDomain domain.Domain) error {
|
||||||
|
newPrefixes := d.extractIPsFromDNSResponse(reply, resolvedDomain)
|
||||||
|
if len(newPrefixes) > 0 {
|
||||||
|
return d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) RemoveRoute() error {
|
func (d *DnsInterceptor) RemoveRoute() error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
|
|
||||||
@ -113,6 +226,7 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.currentPeerKey = peerKey
|
d.currentPeerKey = peerKey
|
||||||
|
go d.preResolveDomains()
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,12 +279,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
client := &dns.Client{
|
|
||||||
Timeout: nbdns.UpstreamTimeout,
|
reply, err := d.exchangeWithUpstream(context.TODO(), r, upstreamIP)
|
||||||
Net: "udp",
|
|
||||||
}
|
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
|
||||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
@ -235,43 +345,13 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
|
resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
|
||||||
|
|
||||||
// already punycode via RegisterHandler()
|
|
||||||
originalDomain := domain.Domain(origPattern)
|
originalDomain := domain.Domain(origPattern)
|
||||||
if originalDomain == "" {
|
if originalDomain == "" {
|
||||||
originalDomain = resolvedDomain
|
originalDomain = resolvedDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
var newPrefixes []netip.Prefix
|
if err := d.processResolveResponse(r, resolvedDomain, originalDomain); err != nil {
|
||||||
for _, answer := range r.Answer {
|
log.Errorf("failed to process DNS response: %v", err)
|
||||||
var ip netip.Addr
|
|
||||||
switch rr := answer.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
addr, ok := netip.AddrFromSlice(rr.A)
|
|
||||||
if !ok {
|
|
||||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ip = addr
|
|
||||||
case *dns.AAAA:
|
|
||||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
|
||||||
if !ok {
|
|
||||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ip = addr
|
|
||||||
default:
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
|
||||||
newPrefixes = append(newPrefixes, prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(newPrefixes) > 0 {
|
|
||||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
|
||||||
log.Errorf("failed to update domain prefixes: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user