Add dns forwarder service

This commit is contained in:
Zoltan Papp
2024-12-10 19:14:09 +01:00
committed by Viktor Liu
parent d802b7b9ba
commit d020755dd5
3 changed files with 224 additions and 11 deletions

View File

@@ -0,0 +1,88 @@
package dnsfwd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"net"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
ListenPort = 5353
)
type Manager struct {
Firewall firewall.Manager
dnsRules []firewall.Rule
service *DNSForwarder
}
func (m *Manager) Start() error {
log.Infof("starting DNS forwarder")
if m.service != nil {
return nil
}
if err := m.allowDNSFirewall(); err != nil {
return err
}
m.service = &DNSForwarder{
// todo listen only NetBird interface
ListenAddress: fmt.Sprintf(":%d", ListenPort),
TTL: 300,
}
go func() {
if err := m.service.Listen(); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
}()
return nil
}
func (m *Manager) Stop(ctx context.Context) error {
if m.service == nil {
return nil
}
err := m.service.Close(ctx)
m.service = nil
return err
}
func (h *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []int{ListenPort},
}
dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.dnsRules = dnsRules
return nil
}
func (h *Manager) dropDNSFirewall() error {
if len(h.dnsRules) == 0 {
return nil
}
for _, rule := range h.dnsRules {
if err := h.Firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete DNS router rules, err: %v", err)
return err
}
}
h.dnsRules = nil
return nil
}

View File

@@ -0,0 +1,91 @@
package dnsfwd
import (
"context"
"net"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type DNSForwarder struct {
ListenAddress string
TTL uint32
dnsServer *dns.Server
}
func (f *DNSForwarder) Listen() error {
log.Infof("listen DNS forwarder on: %s", f.ListenAddress)
mux := dns.NewServeMux()
mux.HandleFunc(".", f.handleDNSQuery)
dnsServer := &dns.Server{
Addr: f.ListenAddress,
Net: "udp",
Handler: mux,
}
f.dnsServer = dnsServer
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil {
return nil
}
return f.dnsServer.ShutdownContext(ctx)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
log.Debugf("received DNS query for DNS forwarder: %v", query)
if len(query.Question) == 0 {
return
}
question := query.Question[0]
domain := question.Name
resp := query.SetReply(query)
ips, err := net.LookupIP(domain)
if err != nil {
log.Errorf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeServerFailure
_ = w.WriteMsg(resp)
return
}
for _, ip := range ips {
log.Infof("resolved domain %s to IP %s", domain, ip)
var respRecord dns.RR
if ip.To4() == nil {
log.Infof("resolved domain %s to IPv6 %s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.TTL,
},
}
respRecord = &rr
} else {
rr := dns.A{
A: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.TTL,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}

View File

@@ -16,6 +16,8 @@ import (
"sync/atomic"
"time"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/pion/ice/v3"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
@@ -159,6 +161,7 @@ type Engine struct {
firewall manager.Manager
routeManager routemanager.Manager
acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
dnsServer dns.Server
@@ -282,6 +285,13 @@ func (e *Engine) Stop() error {
e.routeManager.Stop(e.stateManager)
}
if e.dnsForwardMgr != nil {
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
@@ -799,13 +809,30 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.acl.ApplyFiltering(networkMap)
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
isDNSRouter, routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
if err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
if isDNSRouter {
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = &dnsfwd.Manager{
Firewall: e.firewall,
}
if err := e.dnsForwardMgr.Start(); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
}
}
} else {
if e.dnsForwardMgr != nil {
// todo: review context
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
@@ -868,7 +895,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
var isDNSRouter bool
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
var prefix netip.Prefix
@@ -879,6 +911,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
continue
}
}
isDNSRouter = true
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
Network: prefix,
@@ -892,7 +926,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
}
routes = append(routes, convertedRoute)
}
return routes
return isDNSRouter, routes
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
@@ -1226,7 +1260,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
if err != nil {
return nil, nil, err
}
routes := toRoutes(netMap.GetRoutes())
_, routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
return routes, &dnsCfg, nil
}