Add DNS interceptor

This commit is contained in:
Viktor Liu 2024-12-10 11:03:40 +01:00
parent 97bb74f824
commit d77ac20760
7 changed files with 240 additions and 6 deletions

View File

@ -3,6 +3,7 @@ package dns
import (
"fmt"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
nbdns "github.com/netbirdio/netbird/dns"
)
@ -13,6 +14,10 @@ type MockServer struct {
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
}
func (m *MockServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
return nil
}
// Initialize mock implementation of Initialize from Server interface
func (m *MockServer) Initialize() error {
if m.InitializeFunc != nil {

View File

@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns"
)
@ -30,6 +31,7 @@ type IosDnsManager interface {
// Server is a dns server interface
type Server interface {
RegisterHandler(handler *dnsinterceptor.RouteMatchHandler) error
Initialize() error
Stop()
DnsIP() string
@ -151,6 +153,10 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
return defaultServer
}
func (m *DefaultServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
return nil
}
// Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock()

View File

@ -382,6 +382,7 @@ func (e *Engine) Start() error {
e.relayManager,
initialRoutes,
e.stateManager,
dnsServer,
)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {

View File

@ -251,7 +251,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil)
_, _, err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{

View File

@ -13,12 +13,15 @@ import (
"github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
const useNewDNSRoute = true
type routerPeerStatus struct {
connected bool
relayed bool
@ -53,7 +56,16 @@ type clientNetwork struct {
updateSerial uint64
}
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
func newClientNetworkWatcher(
ctx context.Context,
dnsRouteInterval time.Duration,
wgInterface iface.IWGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
@ -65,7 +77,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface, dnsServer),
}
return client
}
@ -368,8 +380,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
func handlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.IWGIface,
dnsServer nbdns.Server,
) RouteHandler {
if rt.IsDynamic() {
if useNewDNSRoute {
return dnsinterceptor.New(rt, routeRefCounter, allowedIPsRefCounter, statusRecorder, dnsServer)
}
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
}

View File

@ -0,0 +1,177 @@
package dnsinterceptor
import (
"context"
"fmt"
"net/netip"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type RouteMatchHandler struct {
mu sync.RWMutex
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
statusRecorder *peer.Status
dnsServer nbdns.Server
currentPeerKey string
domainRoutes map[string]*route.Route
}
func New(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
) routemanager.RouteHandler {
return &RouteMatchHandler{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder,
dnsServer: dnsServer,
domainRoutes: make(map[string]*route.Route),
}
}
func (h *RouteMatchHandler) String() string {
return fmt.Sprintf("dns route for domains: %v", h.route.Domains)
}
func (h *RouteMatchHandler) AddRoute(ctx context.Context) error {
h.mu.Lock()
defer h.mu.Unlock()
for _, domain := range h.route.Domains {
pattern := dns.Fqdn(string(domain))
h.domainRoutes[pattern] = h.route
}
return h.dnsServer.RegisterHandler(h)
}
func (h *RouteMatchHandler) RemoveRoute() error {
h.mu.Lock()
defer h.mu.Unlock()
h.domainRoutes = make(map[string]*route.Route)
return nil
}
func (h *RouteMatchHandler) AddAllowedIPs(peerKey string) error {
h.mu.Lock()
defer h.mu.Unlock()
h.currentPeerKey = peerKey
return nil
}
func (h *RouteMatchHandler) RemoveAllowedIPs() error {
h.mu.Lock()
defer h.mu.Unlock()
h.currentPeerKey = ""
return nil
}
type responseInterceptor struct {
dns.ResponseWriter
handler *RouteMatchHandler
question dns.Question
answered bool
}
func (i *responseInterceptor) WriteMsg(resp *dns.Msg) error {
if i.answered {
return nil
}
i.answered = true
if resp == nil || len(resp.Answer) == 0 {
return i.ResponseWriter.WriteMsg(resp)
}
i.handler.mu.RLock()
defer i.handler.mu.RUnlock()
questionName := i.question.Name
for _, ans := range resp.Answer {
var ip netip.Addr
switch rr := ans.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
ip = addr
default:
continue
}
if route := i.handler.findMatchingRoute(questionName); route != nil {
i.handler.processMatch(route, questionName, ip)
}
}
return i.ResponseWriter.WriteMsg(resp)
}
func (h *RouteMatchHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
interceptor := &responseInterceptor{
ResponseWriter: w,
handler: h,
question: r.Question[0],
}
h.dnsServer.ServeDNS(interceptor, r)
}
func (h *RouteMatchHandler) findMatchingRoute(domain string) *route.Route {
domain = strings.ToLower(domain)
if route, ok := h.domainRoutes[domain]; ok {
return route
}
labels := dns.SplitDomainName(domain)
if labels == nil {
return nil
}
for i := 0; i < len(labels); i++ {
wildcard := "*." + strings.Join(labels[i:], ".") + "."
if route, ok := h.domainRoutes[wildcard]; ok {
return route
}
}
return nil
}
func (h *RouteMatchHandler) processMatch(route *route.Route, domain string, ip netip.Addr) {
network := netip.PrefixFrom(ip, ip.BitLen())
if h.currentPeerKey == "" {
return
}
if err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
log.Errorf("Failed to add allowed IP %s: %v", network, err)
}
}

View File

@ -16,6 +16,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@ -60,6 +61,7 @@ type DefaultManager struct {
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
stateManager *statemanager.Manager
dnsServer dns.Server
}
func NewManager(
@ -71,6 +73,7 @@ func NewManager(
relayMgr *relayClient.Manager,
initialRoutes []*route.Route,
stateManager *statemanager.Manager,
dnsServer dns.Server,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier()
@ -88,6 +91,7 @@ func NewManager(
pubKey: pubKey,
notifier: notifier,
stateManager: stateManager,
dnsServer: dnsServer,
}
dm.routeRefCounter = refcounter.New(
@ -273,7 +277,16 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
clientNetworkWatcher := newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@ -302,7 +315,16 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
clientNetworkWatcher = newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}