[client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)

[client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
This commit is contained in:
hakansa
2025-04-07 15:16:12 +08:00
committed by GitHub
parent 86dbb4ee4f
commit 1ba1e092ce
6 changed files with 183 additions and 69 deletions

View File

@ -5,11 +5,14 @@ import (
"errors"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
)
@ -17,23 +20,27 @@ const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
type DNSForwarder struct {
listenAddress string
ttl uint32
domains []string
listenAddress string
ttl uint32
domains []string
statusRecorder *peer.Status
dnsServer *dns.Server
mux *dns.ServeMux
resId sync.Map
}
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
listenAddress: listenAddress,
ttl: ttl,
statusRecorder: statusRecorder,
}
}
func (f *DNSForwarder) Listen(domains []string) error {
func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux()
@ -45,22 +52,31 @@ func (f *DNSForwarder) Listen(domains []string) error {
f.dnsServer = dnsServer
f.mux = mux
f.UpdateDomains(domains)
f.UpdateDomains(domains, resIds)
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string) {
func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) {
log.Debugf("Updating domains from %v to %v", f.domains, domains)
for _, d := range f.domains {
f.mux.HandleRemove(d)
f.statusRecorder.RemoveResolvedIPLookupEntry(d)
}
f.resId.Clear()
newDomains := filterDomains(domains)
for _, d := range newDomains {
f.mux.HandleFunc(d, f.handleDNSQuery)
}
for domain, resId := range resIds {
if domain != "" {
f.resId.Store(domain, resId)
}
}
f.domains = newDomains
}
@ -106,6 +122,21 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
return
}
resId, ok := f.resId.Load(strings.TrimSuffix(domain, "."))
if ok {
for _, ip := range ips {
var ipWithSuffix string
if ip.Is4() {
ipWithSuffix = ip.String() + "/32"
log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix)
} else {
ipWithSuffix = ip.String() + "/128"
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
}
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId.(string))
}
}
f.addIPsToResponse(resp, domain, ips)
if err := w.WriteMsg(resp); err != nil {