mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-19 11:20:18 +02:00
Add dns forwarder service
This commit is contained in:
88
client/internal/dnsfwd/manager.go
Normal file
88
client/internal/dnsfwd/manager.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
ListenPort = 5353
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
Firewall firewall.Manager
|
||||
|
||||
dnsRules []firewall.Rule
|
||||
service *DNSForwarder
|
||||
}
|
||||
|
||||
func (m *Manager) Start() error {
|
||||
log.Infof("starting DNS forwarder")
|
||||
if m.service != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.allowDNSFirewall(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.service = &DNSForwarder{
|
||||
// todo listen only NetBird interface
|
||||
ListenAddress: fmt.Sprintf(":%d", ListenPort),
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := m.service.Listen(); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
if m.service == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := m.service.Close(ctx)
|
||||
m.service = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []int{ListenPort},
|
||||
}
|
||||
dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
}
|
||||
h.dnsRules = dnsRules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Manager) dropDNSFirewall() error {
|
||||
if len(h.dnsRules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, rule := range h.dnsRules {
|
||||
if err := h.Firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete DNS router rules, err: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
h.dnsRules = nil
|
||||
return nil
|
||||
}
|
91
client/internal/dnsfwd/service.go
Normal file
91
client/internal/dnsfwd/service.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type DNSForwarder struct {
|
||||
ListenAddress string
|
||||
TTL uint32
|
||||
|
||||
dnsServer *dns.Server
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen() error {
|
||||
log.Infof("listen DNS forwarder on: %s", f.ListenAddress)
|
||||
mux := dns.NewServeMux()
|
||||
mux.HandleFunc(".", f.handleDNSQuery)
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Addr: f.ListenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
}
|
||||
f.dnsServer = dnsServer
|
||||
return dnsServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
if f.dnsServer == nil {
|
||||
return nil
|
||||
}
|
||||
return f.dnsServer.ShutdownContext(ctx)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
log.Debugf("received DNS query for DNS forwarder: %v", query)
|
||||
if len(query.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
question := query.Question[0]
|
||||
domain := question.Name
|
||||
|
||||
resp := query.SetReply(query)
|
||||
|
||||
ips, err := net.LookupIP(domain)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve query for domain %s: %v", domain, err)
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
_ = w.WriteMsg(resp)
|
||||
return
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
log.Infof("resolved domain %s to IP %s", domain, ip)
|
||||
var respRecord dns.RR
|
||||
if ip.To4() == nil {
|
||||
log.Infof("resolved domain %s to IPv6 %s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.TTL,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
} else {
|
||||
rr := dns.A{
|
||||
A: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.TTL,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
}
|
||||
resp.Answer = append(resp.Answer, respRecord)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
@@ -16,6 +16,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -159,6 +161,7 @@ type Engine struct {
|
||||
firewall manager.Manager
|
||||
routeManager routemanager.Manager
|
||||
acl acl.Manager
|
||||
dnsForwardMgr *dnsfwd.Manager
|
||||
|
||||
dnsServer dns.Server
|
||||
|
||||
@@ -282,6 +285,13 @@ func (e *Engine) Stop() error {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.dnsForwardMgr != nil {
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
@@ -799,13 +809,30 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
isDNSRouter, routes := toRoutes(networkMap.GetRoutes())
|
||||
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
if err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
if isDNSRouter {
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = &dnsfwd.Manager{
|
||||
Firewall: e.firewall,
|
||||
}
|
||||
|
||||
if err := e.dnsForwardMgr.Start(); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if e.dnsForwardMgr != nil {
|
||||
// todo: review context
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
@@ -868,7 +895,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
var isDNSRouter bool
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
var prefix netip.Prefix
|
||||
@@ -879,6 +911,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
continue
|
||||
}
|
||||
}
|
||||
isDNSRouter = true
|
||||
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
@@ -892,7 +926,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
return routes
|
||||
return isDNSRouter, routes
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
@@ -1226,7 +1260,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
_, routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||
return routes, &dnsCfg, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user