mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-13 18:31:18 +01:00
Add DNS interceptor
This commit is contained in:
parent
97bb74f824
commit
d77ac20760
@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
@ -13,6 +14,10 @@ type MockServer struct {
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize mock implementation of Initialize from Server interface
|
||||
func (m *MockServer) Initialize() error {
|
||||
if m.InitializeFunc != nil {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
@ -30,6 +31,7 @@ type IosDnsManager interface {
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
RegisterHandler(handler *dnsinterceptor.RouteMatchHandler) error
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() string
|
||||
@ -151,6 +153,10 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
func (m *DefaultServer) RegisterHandler(*dnsinterceptor.RouteMatchHandler) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize instantiate host manager and the dns service
|
||||
func (s *DefaultServer) Initialize() (err error) {
|
||||
s.mux.Lock()
|
||||
|
@ -382,6 +382,7 @@ func (e *Engine) Start() error {
|
||||
e.relayManager,
|
||||
initialRoutes,
|
||||
e.stateManager,
|
||||
dnsServer,
|
||||
)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
|
@ -251,7 +251,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil)
|
||||
_, _, err = engine.routeManager.Init()
|
||||
require.NoError(t, err)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
|
@ -13,12 +13,15 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const useNewDNSRoute = true
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
@ -53,7 +56,16 @@ type clientNetwork struct {
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
func newClientNetworkWatcher(
|
||||
ctx context.Context,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface iface.IWGIface,
|
||||
statusRecorder *peer.Status,
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsServer nbdns.Server,
|
||||
) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
@ -65,7 +77,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface, dnsServer),
|
||||
}
|
||||
return client
|
||||
}
|
||||
@ -368,8 +380,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
|
||||
func handlerFromRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsRouterInteval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.IWGIface,
|
||||
dnsServer nbdns.Server,
|
||||
) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
if useNewDNSRoute {
|
||||
return dnsinterceptor.New(rt, routeRefCounter, allowedIPsRefCounter, statusRecorder, dnsServer)
|
||||
}
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
||||
}
|
||||
|
177
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
177
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
@ -0,0 +1,177 @@
|
||||
package dnsinterceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouteMatchHandler struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
statusRecorder *peer.Status
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
domainRoutes map[string]*route.Route
|
||||
}
|
||||
|
||||
func New(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
statusRecorder *peer.Status,
|
||||
dnsServer nbdns.Server,
|
||||
) routemanager.RouteHandler {
|
||||
|
||||
return &RouteMatchHandler{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dnsServer,
|
||||
domainRoutes: make(map[string]*route.Route),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) String() string {
|
||||
return fmt.Sprintf("dns route for domains: %v", h.route.Domains)
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) AddRoute(ctx context.Context) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for _, domain := range h.route.Domains {
|
||||
pattern := dns.Fqdn(string(domain))
|
||||
h.domainRoutes[pattern] = h.route
|
||||
}
|
||||
|
||||
return h.dnsServer.RegisterHandler(h)
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) RemoveRoute() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.domainRoutes = make(map[string]*route.Route)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) AddAllowedIPs(peerKey string) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.currentPeerKey = peerKey
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) RemoveAllowedIPs() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.currentPeerKey = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
type responseInterceptor struct {
|
||||
dns.ResponseWriter
|
||||
handler *RouteMatchHandler
|
||||
question dns.Question
|
||||
answered bool
|
||||
}
|
||||
|
||||
func (i *responseInterceptor) WriteMsg(resp *dns.Msg) error {
|
||||
if i.answered {
|
||||
return nil
|
||||
}
|
||||
i.answered = true
|
||||
|
||||
if resp == nil || len(resp.Answer) == 0 {
|
||||
return i.ResponseWriter.WriteMsg(resp)
|
||||
}
|
||||
|
||||
i.handler.mu.RLock()
|
||||
defer i.handler.mu.RUnlock()
|
||||
|
||||
questionName := i.question.Name
|
||||
for _, ans := range resp.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := ans.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if route := i.handler.findMatchingRoute(questionName); route != nil {
|
||||
i.handler.processMatch(route, questionName, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return i.ResponseWriter.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
interceptor := &responseInterceptor{
|
||||
ResponseWriter: w,
|
||||
handler: h,
|
||||
question: r.Question[0],
|
||||
}
|
||||
|
||||
h.dnsServer.ServeDNS(interceptor, r)
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) findMatchingRoute(domain string) *route.Route {
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
if route, ok := h.domainRoutes[domain]; ok {
|
||||
return route
|
||||
}
|
||||
|
||||
labels := dns.SplitDomainName(domain)
|
||||
if labels == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < len(labels); i++ {
|
||||
wildcard := "*." + strings.Join(labels[i:], ".") + "."
|
||||
if route, ok := h.domainRoutes[wildcard]; ok {
|
||||
return route
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *RouteMatchHandler) processMatch(route *route.Route, domain string, ip netip.Addr) {
|
||||
network := netip.PrefixFrom(ip, ip.BitLen())
|
||||
|
||||
if h.currentPeerKey == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
|
||||
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
||||
}
|
||||
}
|
@ -16,6 +16,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
@ -60,6 +61,7 @@ type DefaultManager struct {
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
stateManager *statemanager.Manager
|
||||
dnsServer dns.Server
|
||||
}
|
||||
|
||||
func NewManager(
|
||||
@ -71,6 +73,7 @@ func NewManager(
|
||||
relayMgr *relayClient.Manager,
|
||||
initialRoutes []*route.Route,
|
||||
stateManager *statemanager.Manager,
|
||||
dnsServer dns.Server,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
notifier := notifier.NewNotifier()
|
||||
@ -88,6 +91,7 @@ func NewManager(
|
||||
pubKey: pubKey,
|
||||
notifier: notifier,
|
||||
stateManager: stateManager,
|
||||
dnsServer: dnsServer,
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
@ -273,7 +277,16 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
clientNetworkWatcher := newClientNetworkWatcher(
|
||||
m.ctx,
|
||||
m.dnsRouteInterval,
|
||||
m.wgInterface,
|
||||
m.statusRecorder,
|
||||
routes[0],
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsServer,
|
||||
)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
@ -302,7 +315,16 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
clientNetworkWatcher = newClientNetworkWatcher(
|
||||
m.ctx,
|
||||
m.dnsRouteInterval,
|
||||
m.wgInterface,
|
||||
m.statusRecorder,
|
||||
routes[0],
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsServer,
|
||||
)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user