mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-09 07:15:15 +02:00
[client, management] Add new network concept (#3047)
--------- Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
356
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
356
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
@ -0,0 +1,356 @@
|
||||
package dnsinterceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type DnsInterceptor struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
statusRecorder *peer.Status
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
interceptedDomains domainMap
|
||||
peerStore *peerstore.Store
|
||||
}
|
||||
|
||||
func New(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
statusRecorder *peer.Status,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
) *DnsInterceptor {
|
||||
return &DnsInterceptor{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dnsServer,
|
||||
interceptedDomains: make(domainMap),
|
||||
peerStore: peerStore,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) String() string {
|
||||
return d.route.Domains.SafeString()
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveRoute() error {
|
||||
d.mu.Lock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
}
|
||||
for _, domain := range d.route.Domains {
|
||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
clear(d.interceptedDomains)
|
||||
d.mu.Unlock()
|
||||
|
||||
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.currentPeerKey = peerKey
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.currentPeerKey = ""
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// ServeDNS implements the dns.Handler interface
|
||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
|
||||
|
||||
d.continueToNextHandler(w, r, "no current peer key")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get upstream IP: %v", err)
|
||||
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
// Set Zero bit to signal handler chain to continue
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed writing DNS continue response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
|
||||
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||
}
|
||||
return peerAllowedIP, nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
if r == nil {
|
||||
return fmt.Errorf("received nil DNS message")
|
||||
}
|
||||
|
||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||
origPattern := ""
|
||||
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
||||
origPattern = writer.GetOrigPattern()
|
||||
}
|
||||
|
||||
resolvedDomain := domain.Domain(r.Question[0].Name)
|
||||
|
||||
// already punycode via RegisterHandler()
|
||||
originalDomain := domain.Domain(origPattern)
|
||||
if originalDomain == "" {
|
||||
originalDomain = resolvedDomain
|
||||
}
|
||||
|
||||
var newPrefixes []netip.Prefix
|
||||
for _, answer := range r.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip, ip.BitLen())
|
||||
newPrefixes = append(newPrefixes, prefix)
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(r); err != nil {
|
||||
return fmt.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
oldPrefixes := d.interceptedDomains[resolvedDomain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Add new prefixes
|
||||
for _, prefix := range toAdd {
|
||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if d.currentPeerKey == "" {
|
||||
continue
|
||||
}
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
resolvedDomain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
for _, prefix := range toRemove {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update domain prefixes using resolved domain as key
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
|
||||
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = false
|
||||
}
|
||||
for _, prefix := range newPrefixes {
|
||||
if _, exists := prefixSet[prefix]; exists {
|
||||
prefixSet[prefix] = true
|
||||
} else {
|
||||
toAdd = append(toAdd, prefix)
|
||||
}
|
||||
}
|
||||
for prefix, inUse := range prefixSet {
|
||||
if !inUse {
|
||||
toRemove = append(toRemove, prefix)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user