DNS forwarder (#3024)

* Add dns forwarder service

- do not serve unmanaged domains
- response the dns server with proper codes
- add update operation
This commit is contained in:
Zoltan Papp 2024-12-11 14:47:55 +01:00 committed by GitHub
parent d020755dd5
commit 619d899047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 113 additions and 66 deletions

View File

@ -9,26 +9,50 @@ import (
)
type DNSForwarder struct {
ListenAddress string
TTL uint32
listenAddress string
ttl uint32
domains []string
dnsServer *dns.Server
mux *dns.ServeMux
}
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
domains: domains,
}
}
func (f *DNSForwarder) Listen() error {
log.Infof("listen DNS forwarder on: %s", f.ListenAddress)
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
mux := dns.NewServeMux()
mux.HandleFunc(".", f.handleDNSQuery)
for _, d := range f.domains {
mux.HandleFunc(d, f.handleDNSQuery)
}
dnsServer := &dns.Server{
Addr: f.ListenAddress,
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
f.dnsServer = dnsServer
f.mux = mux
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string) {
for _, d := range f.domains {
f.mux.HandleRemove(d)
}
for _, d := range domains {
f.mux.HandleFunc(d, f.handleDNSQuery)
}
f.domains = domains
}
func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil {
return nil
@ -37,7 +61,7 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
log.Debugf("received DNS query for DNS forwarder: %v", query)
log.Tracef("received DNS query for DNS forwarder: %v", query)
if len(query.Question) == 0 {
return
}
@ -49,8 +73,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
ips, err := net.LookupIP(domain)
if err != nil {
log.Errorf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeServerFailure
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeRefused
_ = w.WriteMsg(resp)
return
}
@ -66,7 +90,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.TTL,
Ttl: f.ttl,
},
}
respRecord = &rr
@ -77,7 +101,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.TTL,
Ttl: f.ttl,
},
}
respRecord = &rr

View File

@ -3,26 +3,37 @@ package dnsfwd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"net"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
ListenPort = 5353
dnsTTL = 60 //seconds
)
type Manager struct {
Firewall firewall.Manager
firewall firewall.Manager
dnsRules []firewall.Rule
service *DNSForwarder
fwRules []firewall.Rule
dnsForwarder *DNSForwarder
}
func (m *Manager) Start() error {
func NewManager(fw firewall.Manager) *Manager {
return &Manager{
firewall: fw,
}
}
func (m *Manager) Start(domains []string) error {
log.Infof("starting DNS forwarder")
if m.service != nil {
if m.dnsForwarder != nil {
return nil
}
@ -30,14 +41,9 @@ func (m *Manager) Start() error {
return err
}
m.service = &DNSForwarder{
// todo listen only NetBird interface
ListenAddress: fmt.Sprintf(":%d", ListenPort),
TTL: 300,
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains)
go func() {
if err := m.service.Listen(); err != nil {
if err := m.dnsForwarder.Listen(); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
@ -46,14 +52,30 @@ func (m *Manager) Start() error {
return nil
}
func (m *Manager) UpdateDomains(domains []string) {
if m.dnsForwarder == nil {
return
}
m.dnsForwarder.UpdateDomains(domains)
}
func (m *Manager) Stop(ctx context.Context) error {
if m.service == nil {
if m.dnsForwarder == nil {
return nil
}
err := m.service.Close(ctx)
m.service = nil
return err
var mErr *multierror.Error
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
if err := m.dnsForwarder.Close(ctx); err != nil {
mErr = multierror.Append(mErr, err)
}
m.dnsForwarder = nil
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) allowDNSFirewall() error {
@ -61,28 +83,24 @@ func (h *Manager) allowDNSFirewall() error {
IsRange: false,
Values: []int{ListenPort},
}
dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.dnsRules = dnsRules
h.fwRules = dnsRules
return nil
}
func (h *Manager) dropDNSFirewall() error {
if len(h.dnsRules) == 0 {
return nil
}
for _, rule := range h.dnsRules {
if err := h.Firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete DNS router rules, err: %v", err)
return err
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
h.dnsRules = nil
return nil
h.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
}

View File

@ -16,8 +16,6 @@ import (
"sync/atomic"
"time"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/pion/ice/v3"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
@ -31,6 +29,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
@ -789,7 +788,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
}
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig())
@ -809,31 +807,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.acl.ApplyFiltering(networkMap)
}
isDNSRouter, routes := toRoutes(networkMap.GetRoutes())
routedDomains, routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
if isDNSRouter {
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = &dnsfwd.Manager{
Firewall: e.firewall,
}
if err := e.dnsForwardMgr.Start(); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
}
}
} else {
if e.dnsForwardMgr != nil {
// todo: review context
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
}
e.updateDNSForwarder(routedDomains)
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
@ -895,12 +875,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
var isDNSRouter bool
var dnsRoutes []string
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
var prefix netip.Prefix
@ -911,7 +891,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
continue
}
}
isDNSRouter = true
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
@ -926,7 +906,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
}
routes = append(routes, convertedRoute)
}
return isDNSRouter, routes
return dnsRoutes, routes
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
@ -1574,6 +1554,31 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nm, nil
}
func (e *Engine) updateDNSForwarder(domains []string) {
if len(domains) > 0 {
log.Infof("enable domain router service for domains: %v", domains)
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
if err := e.dnsForwardMgr.Start(domains); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} else {
log.Infof("update domain router service for domains: %v", domains)
e.dnsForwardMgr.UpdateDomains(domains)
}
} else {
if e.dnsForwardMgr != nil {
log.Infof("disable domain router service")
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
}
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {