From 619d899047735514c03af18c9309c4518a6dceb4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 11 Dec 2024 14:47:55 +0100 Subject: [PATCH] DNS forwarder (#3024) * Add dns forwarder service - do not serve unmanaged domains - response the dns server with proper codes - add update operation --- .../dnsfwd/{service.go => forwarder.go} | 44 ++++++++--- client/internal/dnsfwd/manager.go | 76 ++++++++++++------- client/internal/engine.go | 59 +++++++------- 3 files changed, 113 insertions(+), 66 deletions(-) rename client/internal/dnsfwd/{service.go => forwarder.go} (62%) diff --git a/client/internal/dnsfwd/service.go b/client/internal/dnsfwd/forwarder.go similarity index 62% rename from client/internal/dnsfwd/service.go rename to client/internal/dnsfwd/forwarder.go index 251918dc8..1ffde7e49 100644 --- a/client/internal/dnsfwd/service.go +++ b/client/internal/dnsfwd/forwarder.go @@ -9,26 +9,50 @@ import ( ) type DNSForwarder struct { - ListenAddress string - TTL uint32 + listenAddress string + ttl uint32 + domains []string dnsServer *dns.Server + mux *dns.ServeMux } +func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder { + return &DNSForwarder{ + listenAddress: listenAddress, + ttl: ttl, + domains: domains, + } +} func (f *DNSForwarder) Listen() error { - log.Infof("listen DNS forwarder on: %s", f.ListenAddress) + log.Infof("listen DNS forwarder on: %s", f.listenAddress) mux := dns.NewServeMux() - mux.HandleFunc(".", f.handleDNSQuery) + + for _, d := range f.domains { + mux.HandleFunc(d, f.handleDNSQuery) + } dnsServer := &dns.Server{ - Addr: f.ListenAddress, + Addr: f.listenAddress, Net: "udp", Handler: mux, } f.dnsServer = dnsServer + f.mux = mux return dnsServer.ListenAndServe() } +func (f *DNSForwarder) UpdateDomains(domains []string) { + for _, d := range f.domains { + f.mux.HandleRemove(d) + } + + for _, d := range domains { + f.mux.HandleFunc(d, f.handleDNSQuery) + } + f.domains = domains +} + func (f *DNSForwarder) Close(ctx context.Context) error { if f.dnsServer == nil { return nil @@ -37,7 +61,7 @@ func (f *DNSForwarder) Close(ctx context.Context) error { } func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { - log.Debugf("received DNS query for DNS forwarder: %v", query) + log.Tracef("received DNS query for DNS forwarder: %v", query) if len(query.Question) == 0 { return } @@ -49,8 +73,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { ips, err := net.LookupIP(domain) if err != nil { - log.Errorf("failed to resolve query for domain %s: %v", domain, err) - resp.Rcode = dns.RcodeServerFailure + log.Warnf("failed to resolve query for domain %s: %v", domain, err) + resp.Rcode = dns.RcodeRefused _ = w.WriteMsg(resp) return } @@ -66,7 +90,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { Name: domain, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, - Ttl: f.TTL, + Ttl: f.ttl, }, } respRecord = &rr @@ -77,7 +101,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, - Ttl: f.TTL, + Ttl: f.ttl, }, } respRecord = &rr diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 47a90f5a9..bc05e0cec 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -3,26 +3,37 @@ package dnsfwd import ( "context" "fmt" - log "github.com/sirupsen/logrus" "net" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" ) const ( + // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also ListenPort = 5353 + dnsTTL = 60 //seconds ) type Manager struct { - Firewall firewall.Manager + firewall firewall.Manager - dnsRules []firewall.Rule - service *DNSForwarder + fwRules []firewall.Rule + dnsForwarder *DNSForwarder } -func (m *Manager) Start() error { +func NewManager(fw firewall.Manager) *Manager { + return &Manager{ + firewall: fw, + } +} + +func (m *Manager) Start(domains []string) error { log.Infof("starting DNS forwarder") - if m.service != nil { + if m.dnsForwarder != nil { return nil } @@ -30,14 +41,9 @@ func (m *Manager) Start() error { return err } - m.service = &DNSForwarder{ - // todo listen only NetBird interface - ListenAddress: fmt.Sprintf(":%d", ListenPort), - TTL: 300, - } - + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains) go func() { - if err := m.service.Listen(); err != nil { + if err := m.dnsForwarder.Listen(); err != nil { // todo handle close error if it is exists log.Errorf("failed to start DNS forwarder, err: %v", err) } @@ -46,14 +52,30 @@ func (m *Manager) Start() error { return nil } +func (m *Manager) UpdateDomains(domains []string) { + if m.dnsForwarder == nil { + return + } + + m.dnsForwarder.UpdateDomains(domains) +} + func (m *Manager) Stop(ctx context.Context) error { - if m.service == nil { + if m.dnsForwarder == nil { return nil } - err := m.service.Close(ctx) - m.service = nil - return err + var mErr *multierror.Error + if err := m.dropDNSFirewall(); err != nil { + mErr = multierror.Append(mErr, err) + } + + if err := m.dnsForwarder.Close(ctx); err != nil { + mErr = multierror.Append(mErr, err) + } + + m.dnsForwarder = nil + return nberrors.FormatErrorOrNil(mErr) } func (h *Manager) allowDNSFirewall() error { @@ -61,28 +83,24 @@ func (h *Manager) allowDNSFirewall() error { IsRange: false, Values: []int{ListenPort}, } - dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err } - h.dnsRules = dnsRules + h.fwRules = dnsRules return nil } func (h *Manager) dropDNSFirewall() error { - if len(h.dnsRules) == 0 { - return nil - } - - for _, rule := range h.dnsRules { - if err := h.Firewall.DeletePeerRule(rule); err != nil { - log.Errorf("failed to delete DNS router rules, err: %v", err) - return err + var mErr *multierror.Error + for _, rule := range h.fwRules { + if err := h.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) } } - h.dnsRules = nil - return nil + h.fwRules = nil + return nberrors.FormatErrorOrNil(mErr) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 8b78c2391..c36b110d3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -16,8 +16,6 @@ import ( "sync/atomic" "time" - "github.com/netbirdio/netbird/client/internal/dnsfwd" - "github.com/pion/ice/v3" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" @@ -31,6 +29,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" @@ -789,7 +788,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error { } func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { - // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't if networkMap.GetPeerConfig() != nil { err := e.updateConfig(networkMap.GetPeerConfig()) @@ -809,31 +807,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - isDNSRouter, routes := toRoutes(networkMap.GetRoutes()) + routedDomains, routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - if isDNSRouter { - if e.dnsForwardMgr == nil { - e.dnsForwardMgr = &dnsfwd.Manager{ - Firewall: e.firewall, - } - - if err := e.dnsForwardMgr.Start(); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - } - } - } else { - if e.dnsForwardMgr != nil { - // todo: review context - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil - } - } + e.updateDNSForwarder(routedDomains) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) @@ -895,12 +875,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } -func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) { +func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} } - var isDNSRouter bool + var dnsRoutes []string routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { var prefix netip.Prefix @@ -911,7 +891,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) { continue } } - isDNSRouter = true + dnsRoutes = append(dnsRoutes, protoRoute.Domains...) convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), @@ -926,7 +906,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) { } routes = append(routes, convertedRoute) } - return isDNSRouter, routes + return dnsRoutes, routes } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { @@ -1574,6 +1554,31 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nm, nil } +func (e *Engine) updateDNSForwarder(domains []string) { + if len(domains) > 0 { + log.Infof("enable domain router service for domains: %v", domains) + if e.dnsForwardMgr == nil { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall) + + if err := e.dnsForwardMgr.Start(domains); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } + } else { + log.Infof("update domain router service for domains: %v", domains) + e.dnsForwardMgr.UpdateDomains(domains) + } + } else { + if e.dnsForwardMgr != nil { + log.Infof("disable domain router service") + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } + } +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { for _, check := range checks {