netbird/client/internal/routemanager/dnsinterceptor/handler.go
Viktor Liu ddc365f7a0
[client, management] Add new network concept (#3047)
---------

Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-12-20 11:30:28 +01:00

357 lines
10 KiB
Go

package dnsinterceptor
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
type (
	// domainMap maps a resolved domain name to the host prefixes (one per
	// A/AAAA answer address) currently tracked as dynamic routes for it.
	domainMap map[domain.Domain][]netip.Prefix

	// DnsInterceptor is the DNS handler behind a domain-based route. It forwards
	// queries for the route's domains to the current routing peer and turns the
	// resolved addresses into reference-counted routes and allowed IPs.
	DnsInterceptor struct {
		// mu guards currentPeerKey and interceptedDomains.
		mu sync.RWMutex

		route                *route.Route
		routeRefCounter      *refcounter.RouteRefCounter
		allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
		statusRecorder       *peer.Status
		dnsServer            nbdns.Server

		// currentPeerKey is the routing peer's key; "" means no peer is selected.
		currentPeerKey string
		// interceptedDomains holds the prefixes installed per resolved domain.
		interceptedDomains domainMap
		peerStore          *peerstore.Store
	}
)
// New creates a DnsInterceptor for the domain route rt.
//
// It only stores the given dependencies and an empty prefix map. Nothing is
// registered with the DNS server until AddRoute is called.
func New(
	rt *route.Route,
	routeRefCounter *refcounter.RouteRefCounter,
	allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
	statusRecorder *peer.Status,
	dnsServer nbdns.Server,
	peerStore *peerstore.Store,
) *DnsInterceptor {
	d := new(DnsInterceptor)
	d.route = rt
	d.routeRefCounter = routeRefCounter
	d.allowedIPsRefcounter = allowedIPsRefCounter
	d.statusRecorder = statusRecorder
	d.dnsServer = dnsServer
	d.peerStore = peerStore
	d.interceptedDomains = make(domainMap)
	return d
}
// String returns the route's domain list rendered via SafeString, for logging.
func (d *DnsInterceptor) String() string {
	domains := d.route.Domains
	return domains.SafeString()
}
// AddRoute registers the interceptor with the DNS server for the route's
// domains (as a punycode list) at nbdns.PriorityDNSRoute.
// The context is unused, and the returned error is always nil.
func (d *DnsInterceptor) AddRoute(_ context.Context) error {
	patterns := d.route.Domains.ToPunycodeList()
	d.dnsServer.RegisterHandler(patterns, d, nbdns.PriorityDNSRoute)
	return nil
}
// RemoveRoute tears down everything the interceptor installed. For every
// tracked prefix it:
//   - releases the route reference;
//   - releases the allowed-IP reference, when a routing peer is set.
//
// It then deletes the resolved-domain states from the status recorder,
// forgets all tracked prefixes and deregisters the DNS handler.
// Per-prefix failures are aggregated into the returned error.
func (d *DnsInterceptor) RemoveRoute() error {
	var errs *multierror.Error

	// d.mu is held only for the bookkeeping below and is released before
	// DeregisterHandler is called, keeping the lock out of the DNS server call.
	func() {
		d.mu.Lock()
		defer d.mu.Unlock()

		hasPeer := d.currentPeerKey != ""
		for dom, prefixes := range d.interceptedDomains {
			for _, p := range prefixes {
				if _, err := d.routeRefCounter.Decrement(p); err != nil {
					errs = multierror.Append(errs, fmt.Errorf("remove dynamic route for IP %s: %v", p, err))
				}
				if !hasPeer {
					continue
				}
				if _, err := d.allowedIPsRefcounter.Decrement(p); err != nil {
					errs = multierror.Append(errs, fmt.Errorf("remove allowed IP %s: %v", p, err))
				}
			}
			log.Debugf("removed dynamic route(s) for [%s]: %s", dom.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
		}

		for _, dom := range d.route.Domains {
			d.statusRecorder.DeleteResolvedDomainsStates(dom)
		}
		clear(d.interceptedDomains)
	}()

	d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
	return nberrors.FormatErrorOrNil(errs)
}
// AddAllowedIPs takes an allowed-IP reference for peerKey on every prefix
// currently tracked for the route's domains. It then records peerKey as the
// current routing peer.
//
// A warning is logged when a prefix is already held by a different peer.
// Per-prefix failures are aggregated into the returned error.
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	var errs *multierror.Error
	for dom, prefixes := range d.interceptedDomains {
		for _, p := range prefixes {
			ref, err := d.allowedIPsRefcounter.Increment(p, peerKey)
			if err != nil {
				errs = multierror.Append(errs, fmt.Errorf("add allowed IP %s: %v", p, err))
				continue
			}
			if ref.Count <= 1 || ref.Out == peerKey {
				continue
			}
			log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
				p.Addr(),
				dom.SafeString(),
				ref.Out,
			)
		}
	}

	d.currentPeerKey = peerKey
	return nberrors.FormatErrorOrNil(errs)
}
// RemoveAllowedIPs releases the allowed-IP reference taken for the current
// routing peer on every tracked prefix, then clears the current peer.
//
// It is a no-op when no peer is set. This matches updateDomainPrefixes and
// RemoveRoute, which only touch allowed IPs while a peer is set. The
// refcounter is injected and may be shared with other handlers (the
// "already routed by peer" warning relies on counts > 1). Decrementing
// without holding references, e.g. on a repeated call, could therefore
// release references owned by someone else.
// Per-prefix failures are aggregated into the returned error.
func (d *DnsInterceptor) RemoveAllowedIPs() error {
	d.mu.Lock()
	defer d.mu.Unlock()

	if d.currentPeerKey == "" {
		return nil
	}

	var merr *multierror.Error
	for _, prefixes := range d.interceptedDomains {
		for _, prefix := range prefixes {
			if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
			}
		}
	}

	d.currentPeerKey = ""
	return nberrors.FormatErrorOrNil(merr)
}
// ServeDNS implements the dns.Handler interface.
//
// Queries are forwarded over UDP to dnsfwd.ListenPort on the current routing
// peer's allowed IP, looked up in the peer store. Successful replies go
// through writeMsg, which records A/AAAA answers as dynamic routes before
// answering the client.
//
// If no peer is selected, or its address is unknown, the query is handed to
// the next handler in the chain. If the upstream exchange fails, the client
// receives SERVFAIL.
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if len(r.Question) == 0 {
		return
	}
	log.Tracef("received DNS request for domain=%s type=%v class=%v",
		r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)

	d.mu.RLock()
	peerKey := d.currentPeerKey
	d.mu.RUnlock()

	if peerKey == "" {
		log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
		d.continueToNextHandler(w, r, "no current peer key")
		return
	}

	upstreamIP, err := d.getUpstreamIP(peerKey)
	if err != nil {
		log.Errorf("failed to get upstream IP: %v", err)
		d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
		return
	}

	client := &dns.Client{
		Timeout: 5 * time.Second,
		Net:     "udp",
	}
	// JoinHostPort brackets IPv6 literals; "%s:%d" would produce an unusable
	// address such as "fd00::1:5353" for an IPv6 peer.
	upstream := net.JoinHostPort(upstreamIP.String(), fmt.Sprint(dnsfwd.ListenPort))
	reply, _, err := client.ExchangeContext(context.Background(), r, upstream)

	var answer []dns.RR
	if reply != nil {
		answer = reply.Answer
	}
	log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)

	if err != nil {
		log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
		// SetRcode builds a real reply to r. It copies the ID, opcode and
		// question and sets the QR bit, so the client can match the SERVFAIL
		// to its query. A bare header carries none of that.
		resp := new(dns.Msg)
		resp.SetRcode(r, dns.RcodeServerFailure)
		if err := w.WriteMsg(resp); err != nil {
			log.Errorf("failed writing DNS response: %v", err)
		}
		return
	}

	reply.Id = r.Id
	if err := d.writeMsg(w, reply); err != nil {
		log.Errorf("failed writing DNS response: %v", err)
	}
}
// continueToNextHandler signals the handler chain to try the next handler.
//
// The reply is an NXDOMAIN answer to r with the header's Zero bit set. Per the
// chain's convention, the Zero bit marks it as a "continue" signal rather than
// a final answer. reason is only used for logging.
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
	log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)

	msg := &dns.Msg{}
	msg.SetRcode(r, dns.RcodeNameError)
	msg.Zero = true // continue signal for the handler chain

	if writeErr := w.WriteMsg(msg); writeErr != nil {
		log.Errorf("failed writing DNS continue response: %v", writeErr)
	}
}
// getUpstreamIP looks up peerKey's allowed IP in the peer store. ServeDNS uses
// it as the upstream DNS address. An error is returned when the store has no
// entry for the key.
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
	ip, found := d.peerStore.AllowedIP(peerKey)
	if found {
		return ip, nil
	}
	return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
}
// writeMsg records the A/AAAA answers of an upstream reply as dynamic routes
// for the queried name, then writes the reply to the client.
//
// The pattern the query matched in the handler chain (from
// nbdns.ResponseWriterChain) is passed on as the "original" domain. When w is
// not a chain writer, or the pattern is empty, the queried name is used
// instead. A failure to update routes is logged and does not stop the reply
// from being written.
//
// It returns an error if r is nil or if writing the response fails.
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
	if r == nil {
		return fmt.Errorf("received nil DNS message")
	}

	if len(r.Answer) > 0 && len(r.Question) > 0 {
		origPattern := ""
		if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
			origPattern = writer.GetOrigPattern()
		}

		resolvedDomain := domain.Domain(r.Question[0].Name)

		// already punycode via RegisterHandler()
		originalDomain := domain.Domain(origPattern)
		if originalDomain == "" {
			originalDomain = resolvedDomain
		}

		var newPrefixes []netip.Prefix
		// seen drops repeated addresses in the answer. determinePrefixChanges
		// collapses duplicates when computing removals, so a prefix counted
		// twice here would later be released only once, leaking a route reference.
		seen := make(map[netip.Prefix]struct{}, len(r.Answer))
		for _, answer := range r.Answer {
			var ip netip.Addr
			switch rr := answer.(type) {
			case *dns.A:
				addr, ok := netip.AddrFromSlice(rr.A)
				if !ok {
					log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
					continue
				}
				// A 16-byte (IPv4-mapped) A value would otherwise become a /128
				// IPv6 prefix instead of the intended /32.
				ip = addr.Unmap()
			case *dns.AAAA:
				addr, ok := netip.AddrFromSlice(rr.AAAA)
				if !ok {
					log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
					continue
				}
				ip = addr
			default:
				continue
			}

			prefix := netip.PrefixFrom(ip, ip.BitLen())
			if _, dup := seen[prefix]; dup {
				continue
			}
			seen[prefix] = struct{}{}
			newPrefixes = append(newPrefixes, prefix)
		}

		if len(newPrefixes) > 0 {
			if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
				log.Errorf("failed to update domain prefixes: %v", err)
			}
		}
	}

	if err := w.WriteMsg(r); err != nil {
		return fmt.Errorf("failed to write DNS response: %w", err)
	}
	return nil
}
// updateDomainPrefixes reconciles the dynamic routes tracked for
// resolvedDomain with the prefixes from its latest resolution.
//
// Newly seen prefixes get a route reference. When a routing peer is currently
// selected, they also get an allowed-IP reference for that peer. Prefixes
// missing from the new answer are released, unless the route has KeepRoute
// set. In that case they stay routed and stay tracked, so RemoveRoute can
// release them later.
//
// When the tracked set changes, the status recorder is updated with
// originalDomain (trailing dot stripped) and resolvedDomain.
// Per-prefix failures are aggregated into the returned error.
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	oldPrefixes := d.interceptedDomains[resolvedDomain]
	toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)

	if d.route.KeepRoute {
		// Kept routes are never withdrawn, so they must stay tracked. Storing
		// only the latest answer would forget them: RemoveRoute could not
		// release their references, and an address that reappears later would
		// be counted again.
		toRemove = nil
		tracked := make([]netip.Prefix, 0, len(oldPrefixes)+len(toAdd))
		tracked = append(tracked, oldPrefixes...)
		newPrefixes = append(tracked, toAdd...)
	}

	var merr *multierror.Error

	// Add new prefixes
	for _, prefix := range toAdd {
		if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
			continue
		}

		if d.currentPeerKey == "" {
			continue
		}
		if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
		} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
			log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
				prefix.Addr(),
				resolvedDomain.SafeString(),
				ref.Out,
			)
		}
	}

	// Remove old prefixes (toRemove is always empty when KeepRoute is set)
	for _, prefix := range toRemove {
		if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
			merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
		}
		if d.currentPeerKey != "" {
			if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
			}
		}
	}

	// Update domain prefixes using resolved domain as key
	if len(toAdd) > 0 || len(toRemove) > 0 {
		d.interceptedDomains[resolvedDomain] = newPrefixes
		originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
		d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)

		if len(toAdd) > 0 {
			log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
				resolvedDomain.SafeString(),
				originalDomain.SafeString(),
				toAdd)
		}
		if len(toRemove) > 0 {
			log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
				resolvedDomain.SafeString(),
				originalDomain.SafeString(),
				toRemove)
		}
	}

	return nberrors.FormatErrorOrNil(merr)
}
// determinePrefixChanges diffs two prefix lists.
//
//   - toAdd: every entry of newPrefixes that is absent from oldPrefixes, in
//     newPrefixes order, with repeats kept.
//   - toRemove: each distinct entry of oldPrefixes that is absent from
//     newPrefixes, in unspecified order.
//
// Either result is nil when empty.
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
	before := make(map[netip.Prefix]struct{}, len(oldPrefixes))
	for _, p := range oldPrefixes {
		before[p] = struct{}{}
	}
	after := make(map[netip.Prefix]struct{}, len(newPrefixes))
	for _, p := range newPrefixes {
		after[p] = struct{}{}
	}

	for _, p := range newPrefixes {
		if _, known := before[p]; !known {
			toAdd = append(toAdd, p)
		}
	}
	for p := range before {
		if _, kept := after[p]; !kept {
			toRemove = append(toRemove, p)
		}
	}
	return toAdd, toRemove
}