Implement dns routes for Android

This commit is contained in:
Viktor Liu 2025-06-16 20:52:12 +02:00
parent 8df8c1012f
commit bb74e903cd
15 changed files with 1036 additions and 166 deletions

View File

@ -203,10 +203,6 @@ func (c *Client) Networks() *NetworkArray {
continue
}
if routes[0].IsDynamic() {
continue
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)

View File

@ -104,6 +104,11 @@ type Manager struct {
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
}
// decoder for packages
@ -189,6 +194,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
}
m.routingEnabled.Store(false)
@ -519,22 +525,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@ -608,6 +598,14 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return false
}
translated := m.translateOutboundDNAT(packetData, d)
if translated {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after DNAT: %v", err)
return false
}
}
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
@ -618,7 +616,6 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true
}
// for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size)
return false
@ -747,9 +744,17 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return false
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
// Step 1: Check connection tracking FIRST (with original addresses)
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
// Step 2: Apply reverse DNAT for established connections
translated := m.translateInboundReverse(packetData, d)
if translated {
// Re-decode after translation
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
}
return false
}

View File

@ -0,0 +1,309 @@
package uspfilter
import (
"encoding/binary"
"fmt"
"net/netip"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
}
var sum uint32
for i := 0; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
if len(header)%2 == 1 {
sum += uint32(header[len(header)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum uint32
for i := 0; i < len(data)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
}
return ^uint16(sum)
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
}
m.dnatMutex.Lock()
m.dnatMappings[originalAddr] = translatedAddr
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
m.dnatMutex.Unlock()
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
if _, exists := m.dnatMappings[originalAddr]; !exists {
return fmt.Errorf("mapping not found for: %s", originalAddr)
}
delete(m.dnatMappings, originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
}
m.dnatMutex.RLock()
translated, exists := m.dnatMappings[addr]
m.dnatMutex.RUnlock()
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
}
m.dnatMutex.RLock()
defer m.dnatMutex.RUnlock()
for original, translated := range m.dnatMappings {
if translated == translatedAddr {
return original, true
}
}
return translatedAddr, false
}
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
_, dstIP := m.extractIPs(d)
if !dstIP.IsValid() || !dstIP.Is4() {
return false
}
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
srcIP, _ := m.extractIPs(d)
if !srcIP.IsValid() || !srcIP.Is4() {
return false
}
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return fmt.Errorf("only IPv4 supported")
}
oldDst := make([]byte, 4)
copy(oldDst, packetData[16:20])
newDst := newIP.AsSlice()
copy(packetData[16:20], newDst)
ipHeaderLen := int(d.ip4.IHL) * 4
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst, newDst)
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst, newDst)
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return fmt.Errorf("only IPv4 supported")
}
oldSrc := make([]byte, 4)
copy(oldSrc, packetData[12:16])
newSrc := newIP.AsSlice()
copy(packetData[12:16], newSrc)
ipHeaderLen := int(d.ip4.IHL) * 4
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc, newSrc)
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc, newSrc)
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
return
}
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return
}
checksumOffset := udpStart + 6
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum == 0 {
return
}
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
return
}
icmpData := packetData[icmpStart:]
binary.BigEndian.PutUint16(icmpData[2:4], 0)
checksum := icmpChecksum(icmpData)
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}

View File

@ -488,9 +488,9 @@ func (e *Engine) createFirewall() error {
}
func (e *Engine) initFirewall() error {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
return fmt.Errorf("set firewall: %w", err)
}
if e.config.BlockLANAccess {

View File

@ -10,11 +10,10 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
@ -553,41 +552,16 @@ func (w *Watcher) Stop() {
w.currentChosenStatus = nil
}
func HandlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler {
switch handlerType(rt, useNewDNSRoute) {
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
switch handlerType(params.Route, params.UseNewDNSRoute) {
case handlerTypeDnsInterceptor:
return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
wgInterface,
peerStore,
)
return dnsinterceptor.New(params)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
dns := nbdns.NewServiceViaMemory(params.WgInterface)
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
return dynamic.NewRoute(params, dnsAddr)
default:
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
return static.NewRoute(params)
}
}

View File

@ -0,0 +1,28 @@
package common
import (
"time"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type HandlerParams struct {
Route *route.Route
RouteRefCounter *refcounter.RouteRefCounter
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
DnsRouterInteval time.Duration
StatusRecorder *peer.Status
WgInterface iface.WGIface
DnsServer dns.Server
PeerStore *peerstore.Store
UseNewDNSRoute bool
Firewall manager.Manager
FakeIPManager *fakeip.FakeIPManager
}

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
@ -12,11 +13,14 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
@ -24,6 +28,11 @@ import (
type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
RemoveInternalDNATMapping(netip.Addr) error
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface {
Name() string
Address() wgaddr.Address
@ -40,26 +49,22 @@ type DnsInterceptor struct {
interceptedDomains domainMap
wgInterface wgInterface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.FakeIPManager
}
func New(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
wgInterface wgInterface,
peerStore *peerstore.Store,
) *DnsInterceptor {
func New(params common.HandlerParams) *DnsInterceptor {
return &DnsInterceptor{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder,
dnsServer: dnsServer,
wgInterface: wgInterface,
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
statusRecorder: params.StatusRecorder,
dnsServer: params.DnsServer,
wgInterface: params.WgInterface,
peerStore: params.PeerStore,
firewall: params.Firewall,
fakeIPManager: params.FakeIPManager,
interceptedDomains: make(domainMap),
peerStore: peerStore,
}
}
@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error {
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
// Routes should use fake IPs
routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
}
// AllowedIPs should use real IPs
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error {
}
}
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
}
d.cleanupDNATMappings()
for _, domain := range d.route.Domains {
d.statusRecorder.DeleteResolvedDomainsStates(domain)
}
@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error {
return nberrors.FormatErrorOrNil(merr)
}
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
return realPrefix
}
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
}
return realPrefix
}
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
// AllowedIPs always use real IPs
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
if err != nil {
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
}
if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
realPrefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
routePrefix := d.transformRealToFakePrefix(realPrefix)
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
}
// Add to AllowedIPs if we have a current peer (uses real IPs)
if d.currentPeerKey == "" {
return nil
}
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
}
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
if d.currentPeerKey == "" {
return nil
}
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
}
return nil
}
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock()
defer d.mu.Unlock()
@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
// AllowedIPs use real IPs
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
merr = multierror.Append(merr, err)
}
}
}
@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
var merr *multierror.Error
for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
@ -284,6 +353,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes)
}
}
@ -294,6 +365,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
return nil
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
@ -302,70 +389,184 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
var merr *multierror.Error
var dnatMappings map[netip.Addr]netip.Addr
// Handle DNAT mappings for new prefixes
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
dnatMappings = make(map[netip.Addr]netip.Addr)
for _, prefix := range toAdd {
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
// Add new prefixes
for _, prefix := range toAdd {
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
continue
}
if d.currentPeerKey == "" {
continue
}
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
resolvedDomain.SafeString(),
ref.Out,
)
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
merr = multierror.Append(merr, err)
}
}
d.addDNATMappings(dnatMappings)
if !d.route.KeepRoute {
// Remove old prefixes
for _, prefix := range toRemove {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
// Routes use fake IPs
routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
// AllowedIPs use real IPs
if err := d.removeAllowedIP(prefix); err != nil {
merr = multierror.Append(merr, err)
}
}
d.removeDNATMappingsForRealIPs(toRemove)
}
// Update domain prefixes using resolved domain as key
// Update domain prefixes using resolved domain as key - store real IPs
if len(toAdd) > 0 || len(toRemove) > 0 {
if d.route.KeepRoute {
// replace stored prefixes with old + added
// nolint:gocritic
newPrefixes = append(oldPrefixes, toAdd...)
}
d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappingsForRealIPs removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappingsForRealIPs(realPrefixes []netip.Prefix) {
if len(realPrefixes) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for _, prefix := range realPrefixes {
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
}
}
}
}
// internalDnatFw checks if the firewall supports internal DNAT
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
if d.firewall == nil || runtime.GOOS != "android" {
return nil, false
}
fw, ok := d.firewall.(internalDNATer)
return fw, ok
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
if len(mappings) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
// removeDNATMappings removes DNAT mappings from the firewall for removed prefixes
func (d *DnsInterceptor) removeDNATMappings(prefixes []netip.Prefix) {
if len(prefixes) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for _, prefix := range prefixes {
fakeIP := prefix.Addr()
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s", fakeIP)
}
}
}
// cleanupDNATMappings removes all DNAT mappings for this interceptor
func (d *DnsInterceptor) cleanupDNATMappings() {
if _, ok := d.internalDnatFw(); !ok {
return
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappingsForRealIPs(prefixes)
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
if _, ok := d.internalDnatFw(); !ok {
return
}
// Replace A and AAAA records with fake IPs
for _, answer := range reply.Answer {
switch rr := answer.(type) {
case *dns.A:
realIP, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
realIP, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {

View File

@ -14,6 +14,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
@ -52,24 +53,16 @@ type Route struct {
resolverAddr string
}
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
resolverAddr string,
) *Route {
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
interval: params.DnsRouterInteval,
statusRecorder: params.StatusRecorder,
wgInterface: params.WgInterface,
resolverAddr: resolverAddr,
dynamicDomains: domainMap{},
}
}

View File

@ -0,0 +1,93 @@
package fakeip
import (
"fmt"
"net/netip"
"sync"
)
// FakeIPManager manages allocation of fake IPs from the 240.0.0.0/8 block
type FakeIPManager struct {
mu sync.Mutex
nextIP netip.Addr // Next IP to allocate
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
baseIP netip.Addr // First usable IP: 240.0.0.1
maxIP netip.Addr // Last usable IP: 240.255.255.254
}
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
func NewManager() *FakeIPManager {
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
return &FakeIPManager{
nextIP: baseIP,
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: baseIP,
maxIP: maxIP,
}
}
// AllocateFakeIP allocates a fake IP for the given real IP
// Returns the fake IP, or existing fake IP if already allocated
func (f *FakeIPManager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
if !realIP.Is4() {
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
}
f.mu.Lock()
defer f.mu.Unlock()
if fakeIP, exists := f.allocated[realIP]; exists {
return fakeIP, nil
}
startIP := f.nextIP
for {
currentIP := f.nextIP
// Advance to next IP, wrapping at boundary
if f.nextIP.Compare(f.maxIP) >= 0 {
f.nextIP = f.baseIP
} else {
f.nextIP = f.nextIP.Next()
}
// Check if current IP is available
if _, inUse := f.fakeToReal[currentIP]; !inUse {
f.allocated[realIP] = currentIP
f.fakeToReal[currentIP] = realIP
return currentIP, nil
}
// Prevent infinite loop if all IPs exhausted
if f.nextIP.Compare(startIP) == 0 {
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
}
}
}
// GetFakeIP returns the fake IP for a real IP if it exists
func (f *FakeIPManager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
f.mu.Lock()
defer f.mu.Unlock()
fakeIP, exists := f.allocated[realIP]
return fakeIP, exists
}
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
func (f *FakeIPManager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
f.mu.Lock()
defer f.mu.Unlock()
realIP, exists := f.fakeToReal[fakeIP]
return realIP, exists
}
// GetFakeIPBlock returns the fake IP block used by this manager
func (f *FakeIPManager) GetFakeIPBlock() netip.Prefix {
return netip.MustParsePrefix("240.0.0.0/8")
}

View File

@ -0,0 +1,242 @@
package fakeip
import (
"net/netip"
"sync"
"testing"
)
func TestNewManager(t *testing.T) {
manager := NewManager()
if manager.baseIP.String() != "240.0.0.1" {
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
}
if manager.maxIP.String() != "240.255.255.254" {
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
}
if manager.nextIP.Compare(manager.baseIP) != 0 {
t.Errorf("Expected nextIP to start at baseIP")
}
}
func TestAllocateFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("8.8.8.8")
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
if !fakeIP.Is4() {
t.Error("Fake IP should be IPv4")
}
// Check it's in the correct range
if fakeIP.As4()[0] != 240 {
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
}
// Should return same fake IP for same real IP
fakeIP2, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to get existing fake IP: %v", err)
}
if fakeIP.Compare(fakeIP2) != 0 {
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
}
}
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
manager := NewManager()
realIPv6 := netip.MustParseAddr("2001:db8::1")
_, err := manager.AllocateFakeIP(realIPv6)
if err == nil {
t.Error("Expected error for IPv6 address")
}
}
func TestGetFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("1.1.1.1")
// Should not exist initially
_, exists := manager.GetFakeIP(realIP)
if exists {
t.Error("Fake IP should not exist before allocation")
}
// Allocate and check
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate: %v", err)
}
fakeIP, exists := manager.GetFakeIP(realIP)
if !exists {
t.Error("Fake IP should exist after allocation")
}
if fakeIP.Compare(expectedFakeIP) != 0 {
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
}
}
func TestMultipleAllocations(t *testing.T) {
manager := NewManager()
allocations := make(map[netip.Addr]netip.Addr)
// Allocate multiple IPs
for i := 1; i <= 100; i++ {
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
}
// Check for duplicates
for _, existingFake := range allocations {
if fakeIP.Compare(existingFake) == 0 {
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
}
}
allocations[realIP] = fakeIP
}
// Verify all allocations can be retrieved
for realIP, expectedFake := range allocations {
actualFake, exists := manager.GetFakeIP(realIP)
if !exists {
t.Errorf("Missing allocation for %s", realIP.String())
}
if actualFake.Compare(expectedFake) != 0 {
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
}
}
}
func TestGetFakeIPBlock(t *testing.T) {
manager := NewManager()
block := manager.GetFakeIPBlock()
expected := "240.0.0.0/8"
if block.String() != expected {
t.Errorf("Expected %s, got %s", expected, block.String())
}
}
func TestConcurrentAccess(t *testing.T) {
manager := NewManager()
const numGoroutines = 50
const allocationsPerGoroutine = 10
var wg sync.WaitGroup
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
// Concurrent allocations
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for j := 0; j < allocationsPerGoroutine; j++ {
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
return
}
results <- fakeIP
}
}(i)
}
wg.Wait()
close(results)
// Check for duplicates
seen := make(map[netip.Addr]bool)
count := 0
for fakeIP := range results {
if seen[fakeIP] {
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
}
seen[fakeIP] = true
count++
}
if count != numGoroutines*allocationsPerGoroutine {
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
}
}
func TestIPExhaustion(t *testing.T) {
// Create a manager with limited range for testing
manager := &FakeIPManager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
}
// Allocate all available IPs
realIPs := []netip.Addr{
netip.MustParseAddr("1.0.0.1"),
netip.MustParseAddr("1.0.0.2"),
netip.MustParseAddr("1.0.0.3"),
}
for _, realIP := range realIPs {
_, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
}
// Try to allocate one more - should fail
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
if err == nil {
t.Error("Expected exhaustion error")
}
}
func TestWrapAround(t *testing.T) {
// Create manager starting near the end of range
manager := &FakeIPManager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
}
// Allocate the last IP
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
if err != nil {
t.Fatalf("Failed to allocate first IP: %v", err)
}
if fakeIP1.String() != "240.0.0.254" {
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
}
// Next allocation should wrap around to the beginning
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
if err != nil {
t.Fatalf("Failed to allocate second IP: %v", err)
}
if fakeIP2.String() != "240.0.0.1" {
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
}
}

View File

@ -11,6 +11,7 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@ -24,6 +25,8 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@ -38,6 +41,10 @@ import (
"github.com/netbirdio/netbird/version"
)
type internalDNATer interface {
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
// Manager is a route manager interface
type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
@ -49,7 +56,7 @@ type Manager interface {
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
SetFirewall(firewall.Manager) error
Stop(stateManager *statemanager.Manager)
}
@ -89,11 +96,13 @@ type DefaultManager struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
dnsServer dns.Server
firewall firewall.Manager
peerStore *peerstore.Store
useNewDNSRoute bool
disableClientRoutes bool
disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.FakeIPManager
}
func NewManager(config ManagerConfig) *DefaultManager {
@ -129,6 +138,8 @@ func NewManager(config ManagerConfig) *DefaultManager {
}
if runtime.GOOS == "android" {
dm.fakeIPManager = fakeip.NewManager()
cr := dm.initialClientRoutes(config.InitialRoutes)
dm.notifier.SetInitialClientRoutes(cr)
}
@ -222,16 +233,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
return routeselector.NewRouteSelector()
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
if m.disableServerRoutes {
// SetFirewall sets the firewall manager for the DefaultManager
// Not thread-safe, should be called before starting the manager
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
m.firewall = firewall
if m.disableServerRoutes || firewall == nil {
log.Info("server routes are disabled")
return nil
}
if firewall == nil {
return errors.New("firewall manager is not set")
}
var err error
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil {
@ -299,17 +310,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
}
for id, route := range toAdd {
handler := client.HandlerFromRoute(
route,
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsRouteInterval,
m.statusRecorder,
m.wgInterface,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
params := common.HandlerParams{
Route: route,
RouteRefCounter: m.routeRefCounter,
AllowedIPsRefCounter: m.allowedIPsRefCounter,
DnsRouterInteval: m.dnsRouteInterval,
StatusRecorder: m.statusRecorder,
WgInterface: m.wgInterface,
DnsServer: m.dnsServer,
PeerStore: m.peerStore,
UseNewDNSRoute: m.useNewDNSRoute,
Firewall: m.firewall,
FakeIPManager: m.fakeIPManager,
}
handler := client.HandlerFromRoute(params)
if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue
@ -517,9 +531,27 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
for _, routes := range crMap {
rs = append(rs, routes...)
}
fakeIPBlock := m.fakeIPManager.GetFakeIPBlock()
id := uuid.NewString()
fakeIPRoute := &route.Route{
ID: route.ID(id),
Network: fakeIPBlock,
NetID: route.NetID(id),
Peer: m.pubKey,
NetworkType: route.IPv4Network,
}
rs = append(rs, fakeIPRoute)
return rs
}
// supportsInternalDNAT checks if the firewall supports internal DNAT
func (m *DefaultManager) supportsInternalDNAT(fw firewall.Manager) bool {
_, ok := fw.(internalDNATer)
return ok
}
func isRouteSupported(route *route.Route) bool {
if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
return true

View File

@ -15,7 +15,7 @@ import (
// MockManager is the mock instance of a route manager
type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
}
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
func (m *MockManager) SetFirewall(firewall.Manager) error {
panic("implement me")
}

View File

@ -32,10 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() {
continue
}
nets = append(nets, r.Network.String())
}
sort.Strings(nets)

View File

@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
@ -16,11 +17,11 @@ type Route struct {
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
func NewRoute(params common.HandlerParams) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
route: params.Route,
routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: params.AllowedIPsRefCounter,
}
}