mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 18:41:41 +02:00
[client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
[client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
This commit is contained in:
@ -5,11 +5,14 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@ -17,23 +20,27 @@ const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
domains []string
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
domains []string
|
||||
statusRecorder *peer.Status
|
||||
|
||||
dnsServer *dns.Server
|
||||
mux *dns.ServeMux
|
||||
|
||||
resId sync.Map
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen(domains []string) error {
|
||||
func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error {
|
||||
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
@ -45,22 +52,31 @@ func (f *DNSForwarder) Listen(domains []string) error {
|
||||
f.dnsServer = dnsServer
|
||||
f.mux = mux
|
||||
|
||||
f.UpdateDomains(domains)
|
||||
f.UpdateDomains(domains, resIds)
|
||||
|
||||
return dnsServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||
func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) {
|
||||
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
||||
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleRemove(d)
|
||||
f.statusRecorder.RemoveResolvedIPLookupEntry(d)
|
||||
}
|
||||
f.resId.Clear()
|
||||
|
||||
newDomains := filterDomains(domains)
|
||||
for _, d := range newDomains {
|
||||
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||
}
|
||||
|
||||
for domain, resId := range resIds {
|
||||
if domain != "" {
|
||||
f.resId.Store(domain, resId)
|
||||
}
|
||||
}
|
||||
|
||||
f.domains = newDomains
|
||||
}
|
||||
|
||||
@ -106,6 +122,21 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
resId, ok := f.resId.Load(strings.TrimSuffix(domain, "."))
|
||||
if ok {
|
||||
for _, ip := range ips {
|
||||
var ipWithSuffix string
|
||||
if ip.Is4() {
|
||||
ipWithSuffix = ip.String() + "/32"
|
||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix)
|
||||
} else {
|
||||
ipWithSuffix = ip.String() + "/128"
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
|
||||
}
|
||||
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId.(string))
|
||||
}
|
||||
}
|
||||
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -19,19 +20,21 @@ const (
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
|
||||
fwRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager) *Manager {
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(domains []string) error {
|
||||
func (m *Manager) Start(domains []string, resIds map[string]string) error {
|
||||
log.Infof("starting DNS forwarder")
|
||||
if m.dnsForwarder != nil {
|
||||
return nil
|
||||
@ -41,9 +44,9 @@ func (m *Manager) Start(domains []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(domains); err != nil {
|
||||
if err := m.dnsForwarder.Listen(domains, resIds); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||
}
|
||||
@ -52,12 +55,12 @@ func (m *Manager) Start(domains []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateDomains(domains []string) {
|
||||
func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) {
|
||||
if m.dnsForwarder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.dnsForwarder.UpdateDomains(domains)
|
||||
m.dnsForwarder.UpdateDomains(domains, resIds)
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
|
Reference in New Issue
Block a user