mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 18:01:13 +01:00
DNS forwarder (#3024)
* Add dns forwarder service - do not serve unmanaged domains - response the dns server with proper codes - add update operation
This commit is contained in:
parent
d020755dd5
commit
619d899047
@ -9,26 +9,50 @@ import (
|
||||
)
|
||||
|
||||
type DNSForwarder struct {
|
||||
ListenAddress string
|
||||
TTL uint32
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
domains []string
|
||||
|
||||
dnsServer *dns.Server
|
||||
mux *dns.ServeMux
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
domains: domains,
|
||||
}
|
||||
}
|
||||
func (f *DNSForwarder) Listen() error {
|
||||
log.Infof("listen DNS forwarder on: %s", f.ListenAddress)
|
||||
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
|
||||
mux := dns.NewServeMux()
|
||||
mux.HandleFunc(".", f.handleDNSQuery)
|
||||
|
||||
for _, d := range f.domains {
|
||||
mux.HandleFunc(d, f.handleDNSQuery)
|
||||
}
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Addr: f.ListenAddress,
|
||||
Addr: f.listenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
}
|
||||
f.dnsServer = dnsServer
|
||||
f.mux = mux
|
||||
return dnsServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleRemove(d)
|
||||
}
|
||||
|
||||
for _, d := range domains {
|
||||
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||
}
|
||||
f.domains = domains
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
if f.dnsServer == nil {
|
||||
return nil
|
||||
@ -37,7 +61,7 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
log.Debugf("received DNS query for DNS forwarder: %v", query)
|
||||
log.Tracef("received DNS query for DNS forwarder: %v", query)
|
||||
if len(query.Question) == 0 {
|
||||
return
|
||||
}
|
||||
@ -49,8 +73,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
|
||||
ips, err := net.LookupIP(domain)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve query for domain %s: %v", domain, err)
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
_ = w.WriteMsg(resp)
|
||||
return
|
||||
}
|
||||
@ -66,7 +90,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.TTL,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
@ -77,7 +101,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.TTL,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
@ -3,26 +3,37 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
ListenPort = 5353
|
||||
dnsTTL = 60 //seconds
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
Firewall firewall.Manager
|
||||
firewall firewall.Manager
|
||||
|
||||
dnsRules []firewall.Rule
|
||||
service *DNSForwarder
|
||||
fwRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func (m *Manager) Start() error {
|
||||
func NewManager(fw firewall.Manager) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(domains []string) error {
|
||||
log.Infof("starting DNS forwarder")
|
||||
if m.service != nil {
|
||||
if m.dnsForwarder != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -30,14 +41,9 @@ func (m *Manager) Start() error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.service = &DNSForwarder{
|
||||
// todo listen only NetBird interface
|
||||
ListenAddress: fmt.Sprintf(":%d", ListenPort),
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains)
|
||||
go func() {
|
||||
if err := m.service.Listen(); err != nil {
|
||||
if err := m.dnsForwarder.Listen(); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||
}
|
||||
@ -46,14 +52,30 @@ func (m *Manager) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateDomains(domains []string) {
|
||||
if m.dnsForwarder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.dnsForwarder.UpdateDomains(domains)
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
if m.service == nil {
|
||||
if m.dnsForwarder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := m.service.Close(ctx)
|
||||
m.service = nil
|
||||
return err
|
||||
var mErr *multierror.Error
|
||||
if err := m.dropDNSFirewall(); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
|
||||
if err := m.dnsForwarder.Close(ctx); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
|
||||
m.dnsForwarder = nil
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
||||
func (h *Manager) allowDNSFirewall() error {
|
||||
@ -61,28 +83,24 @@ func (h *Manager) allowDNSFirewall() error {
|
||||
IsRange: false,
|
||||
Values: []int{ListenPort},
|
||||
}
|
||||
dnsRules, err := h.Firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
}
|
||||
h.dnsRules = dnsRules
|
||||
h.fwRules = dnsRules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Manager) dropDNSFirewall() error {
|
||||
if len(h.dnsRules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, rule := range h.dnsRules {
|
||||
if err := h.Firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete DNS router rules, err: %v", err)
|
||||
return err
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range h.fwRules {
|
||||
if err := h.firewall.DeletePeerRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
h.dnsRules = nil
|
||||
return nil
|
||||
h.fwRules = nil
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
@ -16,8 +16,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -31,6 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
@ -789,7 +788,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||
if networkMap.GetPeerConfig() != nil {
|
||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||
@ -809,31 +807,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
isDNSRouter, routes := toRoutes(networkMap.GetRoutes())
|
||||
routedDomains, routes := toRoutes(networkMap.GetRoutes())
|
||||
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
if isDNSRouter {
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = &dnsfwd.Manager{
|
||||
Firewall: e.firewall,
|
||||
}
|
||||
|
||||
if err := e.dnsForwardMgr.Start(); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if e.dnsForwardMgr != nil {
|
||||
// todo: review context
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
e.updateDNSForwarder(routedDomains)
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
@ -895,12 +875,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) {
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
var isDNSRouter bool
|
||||
var dnsRoutes []string
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
var prefix netip.Prefix
|
||||
@ -911,7 +891,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
isDNSRouter = true
|
||||
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
||||
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
@ -926,7 +906,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) (bool, []*route.Route) {
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
return isDNSRouter, routes
|
||||
return dnsRoutes, routes
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
@ -1574,6 +1554,31 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
return nm, nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateDNSForwarder(domains []string) {
|
||||
if len(domains) > 0 {
|
||||
log.Infof("enable domain router service for domains: %v", domains)
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(domains); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
} else {
|
||||
log.Infof("update domain router service for domains: %v", domains)
|
||||
e.dnsForwardMgr.UpdateDomains(domains)
|
||||
}
|
||||
} else {
|
||||
if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
|
Loading…
Reference in New Issue
Block a user