mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-20 03:29:19 +02:00
Add dns forwarder service
This commit is contained in:
88
client/internal/dnsfwd/manager.go
Normal file
88
client/internal/dnsfwd/manager.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ListenPort = 5353
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
Firewall firewall.Manager
|
||||||
|
|
||||||
|
dnsRules []firewall.Rule
|
||||||
|
service *DNSForwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start() error {
|
||||||
|
log.Infof("starting DNS forwarder")
|
||||||
|
if m.service != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.allowDNSFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.service = &DNSForwarder{
|
||||||
|
// todo listen only NetBird interface
|
||||||
|
ListenAddress: fmt.Sprintf(":%d", ListenPort),
|
||||||
|
TTL: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := m.service.Listen(); err != nil {
|
||||||
|
// todo handle close error if it is exists
|
||||||
|
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Stop(ctx context.Context) error {
|
||||||
|
if m.service == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.service.Close(ctx)
|
||||||
|
m.service = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) allowDNSFirewall() error {
|
||||||
|
dport := &firewall.Port{
|
||||||
|
IsRange: false,
|
||||||
|
Values: []int{ListenPort},
|
||||||
|
}
|
||||||
|
dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.dnsRules = dnsRules
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) dropDNSFirewall() error {
|
||||||
|
if len(h.dnsRules) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range h.dnsRules {
|
||||||
|
if err := h.Firewall.DeletePeerRule(rule); err != nil {
|
||||||
|
log.Errorf("failed to delete DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.dnsRules = nil
|
||||||
|
return nil
|
||||||
|
}
|
91
client/internal/dnsfwd/service.go
Normal file
91
client/internal/dnsfwd/service.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DNSForwarder struct {
|
||||||
|
ListenAddress string
|
||||||
|
TTL uint32
|
||||||
|
|
||||||
|
dnsServer *dns.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) Listen() error {
|
||||||
|
log.Infof("listen DNS forwarder on: %s", f.ListenAddress)
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
mux.HandleFunc(".", f.handleDNSQuery)
|
||||||
|
|
||||||
|
dnsServer := &dns.Server{
|
||||||
|
Addr: f.ListenAddress,
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
f.dnsServer = dnsServer
|
||||||
|
return dnsServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
|
if f.dnsServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return f.dnsServer.ShutdownContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
log.Debugf("received DNS query for DNS forwarder: %v", query)
|
||||||
|
if len(query.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
question := query.Question[0]
|
||||||
|
domain := question.Name
|
||||||
|
|
||||||
|
resp := query.SetReply(query)
|
||||||
|
|
||||||
|
ips, err := net.LookupIP(domain)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to resolve query for domain %s: %v", domain, err)
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
_ = w.WriteMsg(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
log.Infof("resolved domain %s to IP %s", domain, ip)
|
||||||
|
var respRecord dns.RR
|
||||||
|
if ip.To4() == nil {
|
||||||
|
log.Infof("resolved domain %s to IPv6 %s", domain, ip)
|
||||||
|
rr := dns.AAAA{
|
||||||
|
AAAA: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.TTL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
} else {
|
||||||
|
rr := dns.A{
|
||||||
|
A: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.TTL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
}
|
||||||
|
resp.Answer = append(resp.Answer, respRecord)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
@@ -16,6 +16,8 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
|
|
||||||
"github.com/pion/ice/v3"
|
"github.com/pion/ice/v3"
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -159,6 +161,7 @@ type Engine struct {
|
|||||||
firewall manager.Manager
|
firewall manager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
|
dnsForwardMgr *dnsfwd.Manager
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
|
||||||
@@ -282,6 +285,13 @@ func (e *Engine) Stop() error {
|
|||||||
e.routeManager.Stop(e.stateManager)
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.dnsForwardMgr != nil {
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
@@ -799,13 +809,30 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
protoRoutes := networkMap.GetRoutes()
|
isDNSRouter, routes := toRoutes(networkMap.GetRoutes())
|
||||||
if protoRoutes == nil {
|
|
||||||
protoRoutes = []*mgmProto.Route{}
|
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
|
||||||
|
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)); err != nil {
|
if isDNSRouter {
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
if e.dnsForwardMgr == nil {
|
||||||
|
e.dnsForwardMgr = &dnsfwd.Manager{
|
||||||
|
Firewall: e.firewall,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.dnsForwardMgr.Start(); err != nil {
|
||||||
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if e.dnsForwardMgr != nil {
|
||||||
|
// todo: review context
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
@@ -868,7 +895,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
|
||||||
|
if protoRoutes == nil {
|
||||||
|
protoRoutes = []*mgmProto.Route{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var isDNSRouter bool
|
||||||
routes := make([]*route.Route, 0)
|
routes := make([]*route.Route, 0)
|
||||||
for _, protoRoute := range protoRoutes {
|
for _, protoRoute := range protoRoutes {
|
||||||
var prefix netip.Prefix
|
var prefix netip.Prefix
|
||||||
@@ -879,6 +911,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
isDNSRouter = true
|
||||||
|
|
||||||
convertedRoute := &route.Route{
|
convertedRoute := &route.Route{
|
||||||
ID: route.ID(protoRoute.ID),
|
ID: route.ID(protoRoute.ID),
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
@@ -892,7 +926,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
}
|
}
|
||||||
routes = append(routes, convertedRoute)
|
routes = append(routes, convertedRoute)
|
||||||
}
|
}
|
||||||
return routes
|
return isDNSRouter, routes
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||||
@@ -1226,7 +1260,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
_, routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||||
return routes, &dnsCfg, nil
|
return routes, &dnsCfg, nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user