mirror of
https://github.com/netbirdio/netbird.git
synced 2025-07-13 21:10:47 +02:00
164 lines
4.1 KiB
Go
164 lines
4.1 KiB
Go
package dnsinterceptor
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/netip"
|
|
"sync"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
"github.com/netbirdio/netbird/route"
|
|
)
|
|
|
|
type DnsInterceptor struct {
|
|
mu sync.RWMutex
|
|
route *route.Route
|
|
routeRefCounter *refcounter.RouteRefCounter
|
|
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
|
statusRecorder *peer.Status
|
|
dnsServer nbdns.Server
|
|
currentPeerKey string
|
|
interceptedIPs map[string]netip.Prefix
|
|
}
|
|
|
|
func New(
|
|
rt *route.Route,
|
|
routeRefCounter *refcounter.RouteRefCounter,
|
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
statusRecorder *peer.Status,
|
|
dnsServer nbdns.Server,
|
|
) *DnsInterceptor {
|
|
return &DnsInterceptor{
|
|
route: rt,
|
|
routeRefCounter: routeRefCounter,
|
|
allowedIPsRefcounter: allowedIPsRefCounter,
|
|
statusRecorder: statusRecorder,
|
|
dnsServer: dnsServer,
|
|
interceptedIPs: make(map[string]netip.Prefix),
|
|
}
|
|
}
|
|
|
|
func (h *DnsInterceptor) String() string {
|
|
s, err := h.route.Domains.String()
|
|
if err != nil {
|
|
return h.route.Domains.PunycodeString()
|
|
}
|
|
return s
|
|
}
|
|
|
|
func (h *DnsInterceptor) AddRoute(context.Context) error {
|
|
return h.dnsServer.RegisterHandler(h.route.Domains.ToPunycodeList(), h)
|
|
}
|
|
|
|
func (h *DnsInterceptor) RemoveRoute() error {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
// Remove all intercepted IPs
|
|
for key, prefix := range h.interceptedIPs {
|
|
if _, err := h.routeRefCounter.Decrement(prefix); err != nil {
|
|
log.Errorf("Failed to remove route for IP %s: %v", prefix, err)
|
|
}
|
|
if h.currentPeerKey != "" {
|
|
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
|
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
|
}
|
|
}
|
|
delete(h.interceptedIPs, key)
|
|
}
|
|
|
|
// TODO: remove from mux
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
h.currentPeerKey = peerKey
|
|
|
|
// Re-add all intercepted IPs for the new peer
|
|
for _, prefix := range h.interceptedIPs {
|
|
if _, err := h.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
|
log.Errorf("Failed to add allowed IP %s: %v", prefix, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *DnsInterceptor) RemoveAllowedIPs() error {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if h.currentPeerKey != "" {
|
|
for _, prefix := range h.interceptedIPs {
|
|
if _, err := h.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
|
log.Errorf("Failed to remove allowed IP %s: %v", prefix, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
h.currentPeerKey = ""
|
|
return nil
|
|
}
|
|
|
|
// ServeDNS implements the dns.Handler interface
|
|
func (h *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
log.Debugf("Received DNS request: %v", r)
|
|
if len(r.Question) == 0 {
|
|
return
|
|
}
|
|
|
|
// Create response interceptor to capture the response
|
|
interceptor := &responseInterceptor{
|
|
ResponseWriter: w,
|
|
handler: h,
|
|
question: r.Question[0],
|
|
answered: false,
|
|
}
|
|
|
|
// Let the request pass through with our interceptor
|
|
err := interceptor.WriteMsg(r)
|
|
if err != nil {
|
|
log.Errorf("Failed writing DNS response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (h *DnsInterceptor) processMatch(domain string, ip netip.Addr) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
network := netip.PrefixFrom(ip, ip.BitLen())
|
|
key := fmt.Sprintf("%s:%s", domain, network.String())
|
|
|
|
if _, exists := h.interceptedIPs[key]; exists {
|
|
return
|
|
}
|
|
|
|
if _, err := h.routeRefCounter.Increment(network, struct{}{}); err != nil {
|
|
log.Errorf("Failed to add route for IP %s: %v", network, err)
|
|
return
|
|
}
|
|
|
|
if h.currentPeerKey != "" {
|
|
if _, err := h.allowedIPsRefcounter.Increment(network, h.currentPeerKey); err != nil {
|
|
log.Errorf("Failed to add allowed IP %s: %v", network, err)
|
|
// Rollback route addition
|
|
if _, err := h.routeRefCounter.Decrement(network); err != nil {
|
|
log.Errorf("Failed to rollback route addition for IP %s: %v", network, err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
h.interceptedIPs[key] = network
|
|
log.Debugf("Added route for domain %s -> %s", domain, network)
|
|
}
|