mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-19 11:20:18 +02:00
Add DNS interceptor
This commit is contained in:
@@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,6 +14,10 @@ type MockServer struct {
|
|||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize mock implementation of Initialize from Server interface
|
// Initialize mock implementation of Initialize from Server interface
|
||||||
func (m *MockServer) Initialize() error {
|
func (m *MockServer) Initialize() error {
|
||||||
if m.InitializeFunc != nil {
|
if m.InitializeFunc != nil {
|
||||||
|
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
@@ -30,6 +31,7 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
|
RegisterHandler(handler *dnsinterceptor.RouteMatchHandler) error
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@@ -151,6 +153,10 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *DefaultServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize instantiate host manager and the dns service
|
// Initialize instantiate host manager and the dns service
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
func (s *DefaultServer) Initialize() (err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
|
@@ -382,6 +382,7 @@ func (e *Engine) Start() error {
|
|||||||
e.relayManager,
|
e.relayManager,
|
||||||
initialRoutes,
|
initialRoutes,
|
||||||
e.stateManager,
|
e.stateManager,
|
||||||
|
dnsServer,
|
||||||
)
|
)
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -251,7 +251,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
engine.wgInterface = wgIface
|
engine.wgInterface = wgIface
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil)
|
||||||
_, _, err = engine.routeManager.Init()
|
_, _, err = engine.routeManager.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
@@ -13,12 +13,15 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const useNewDNSRoute = true
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
connected bool
|
connected bool
|
||||||
relayed bool
|
relayed bool
|
||||||
@@ -53,7 +56,16 @@ type clientNetwork struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
func newClientNetworkWatcher(
|
||||||
|
ctx context.Context,
|
||||||
|
dnsRouteInterval time.Duration,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
@@ -65,7 +77,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
|||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface, dnsServer),
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@@ -368,8 +380,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
|
func handlerFromRoute(
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsRouterInteval time.Duration,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
) RouteHandler {
|
||||||
if rt.IsDynamic() {
|
if rt.IsDynamic() {
|
||||||
|
if useNewDNSRoute {
|
||||||
|
return dnsinterceptor.New(rt, routeRefCounter, allowedIPsRefCounter, statusRecorder, dnsServer)
|
||||||
|
}
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
||||||
}
|
}
|
||||||
|
177
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
177
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package dnsinterceptor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RouteMatchHandler struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
route *route.Route
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter
|
||||||
|
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
dnsServer nbdns.Server
|
||||||
|
currentPeerKey string
|
||||||
|
domainRoutes map[string]*route.Route
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
) routemanager.RouteHandler {
|
||||||
|
|
||||||
|
return &RouteMatchHandler{
|
||||||
|
route: rt,
|
||||||
|
routeRefCounter: routeRefCounter,
|
||||||
|
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
dnsServer: dnsServer,
|
||||||
|
domainRoutes: make(map[string]*route.Route),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) String() string {
|
||||||
|
return fmt.Sprintf("dns route for domains: %v", h.route.Domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) AddRoute(ctx context.Context) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
for _, domain := range h.route.Domains {
|
||||||
|
pattern := dns.Fqdn(string(domain))
|
||||||
|
h.domainRoutes[pattern] = h.route
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.dnsServer.RegisterHandler(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) RemoveRoute() error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
h.domainRoutes = make(map[string]*route.Route)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) AddAllowedIPs(peerKey string) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.currentPeerKey = peerKey
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) RemoveAllowedIPs() error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.currentPeerKey = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseInterceptor struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
handler *RouteMatchHandler
|
||||||
|
question dns.Question
|
||||||
|
answered bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *responseInterceptor) WriteMsg(resp *dns.Msg) error {
|
||||||
|
if i.answered {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
i.answered = true
|
||||||
|
|
||||||
|
if resp == nil || len(resp.Answer) == 0 {
|
||||||
|
return i.ResponseWriter.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
i.handler.mu.RLock()
|
||||||
|
defer i.handler.mu.RUnlock()
|
||||||
|
|
||||||
|
questionName := i.question.Name
|
||||||
|
for _, ans := range resp.Answer {
|
||||||
|
var ip netip.Addr
|
||||||
|
switch rr := ans.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.A)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
case *dns.AAAA:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if route := i.handler.findMatchingRoute(questionName); route != nil {
|
||||||
|
i.handler.processMatch(route, questionName, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return i.ResponseWriter.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
interceptor := &responseInterceptor{
|
||||||
|
ResponseWriter: w,
|
||||||
|
handler: h,
|
||||||
|
question: r.Question[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
h.dnsServer.ServeDNS(interceptor, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) findMatchingRoute(domain string) *route.Route {
|
||||||
|
domain = strings.ToLower(domain)
|
||||||
|
|
||||||
|
if route, ok := h.domainRoutes[domain]; ok {
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
|
labels := dns.SplitDomainName(domain)
|
||||||
|
if labels == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(labels); i++ {
|
||||||
|
wildcard := "*." + strings.Join(labels[i:], ".") + "."
|
||||||
|
if route, ok := h.domainRoutes[wildcard]; ok {
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RouteMatchHandler) processMatch(route *route.Route, domain string, ip netip.Addr) {
|
||||||
|
network := netip.PrefixFrom(ip, ip.BitLen())
|
||||||
|
|
||||||
|
if h.currentPeerKey == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
|
||||||
|
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
||||||
|
}
|
||||||
|
}
|
@@ -16,6 +16,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
@@ -60,6 +61,7 @@ type DefaultManager struct {
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
dnsServer dns.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(
|
func NewManager(
|
||||||
@@ -71,6 +73,7 @@ func NewManager(
|
|||||||
relayMgr *relayClient.Manager,
|
relayMgr *relayClient.Manager,
|
||||||
initialRoutes []*route.Route,
|
initialRoutes []*route.Route,
|
||||||
stateManager *statemanager.Manager,
|
stateManager *statemanager.Manager,
|
||||||
|
dnsServer dns.Server,
|
||||||
) *DefaultManager {
|
) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
@@ -88,6 +91,7 @@ func NewManager(
|
|||||||
pubKey: pubKey,
|
pubKey: pubKey,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
stateManager: stateManager,
|
stateManager: stateManager,
|
||||||
|
dnsServer: dnsServer,
|
||||||
}
|
}
|
||||||
|
|
||||||
dm.routeRefCounter = refcounter.New(
|
dm.routeRefCounter = refcounter.New(
|
||||||
@@ -273,7 +277,16 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
clientNetworkWatcher := newClientNetworkWatcher(
|
||||||
|
m.ctx,
|
||||||
|
m.dnsRouteInterval,
|
||||||
|
m.wgInterface,
|
||||||
|
m.statusRecorder,
|
||||||
|
routes[0],
|
||||||
|
m.routeRefCounter,
|
||||||
|
m.allowedIPsRefCounter,
|
||||||
|
m.dnsServer,
|
||||||
|
)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||||
@@ -302,7 +315,16 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
|||||||
for id, routes := range networks {
|
for id, routes := range networks {
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
if !found {
|
if !found {
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
clientNetworkWatcher = newClientNetworkWatcher(
|
||||||
|
m.ctx,
|
||||||
|
m.dnsRouteInterval,
|
||||||
|
m.wgInterface,
|
||||||
|
m.statusRecorder,
|
||||||
|
routes[0],
|
||||||
|
m.routeRefCounter,
|
||||||
|
m.allowedIPsRefCounter,
|
||||||
|
m.dnsServer,
|
||||||
|
)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user