mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-19 00:06:58 +02:00
Implement dns routes for Android
This commit is contained in:
parent
8df8c1012f
commit
bb74e903cd
@ -203,10 +203,6 @@ func (c *Client) Networks() *NetworkArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if routes[0].IsDynamic() {
|
||||
continue
|
||||
}
|
||||
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
|
@ -104,6 +104,11 @@ type Manager struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRule firewall.Rule
|
||||
|
||||
// Internal 1:1 DNAT
|
||||
dnatEnabled atomic.Bool
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@ -189,6 +194,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@ -519,22 +525,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
@ -608,6 +598,14 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error("Failed to re-decode packet after DNAT: %v", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
@ -618,7 +616,6 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// for netflow we keep track even if the firewall is stateless
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
|
||||
return false
|
||||
@ -747,9 +744,17 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
// Step 1: Check connection tracking FIRST (with original addresses)
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
// Step 2: Apply reverse DNAT for established connections
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
// Re-decode after translation
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
309
client/firewall/uspfilter/nat.go
Normal file
309
client/firewall/uspfilter/nat.go
Normal file
@ -0,0 +1,309 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var sum uint32
|
||||
for i := 0; i < len(header)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
|
||||
if len(header)%2 == 1 {
|
||||
sum += uint32(header[len(header)-1]) << 8
|
||||
}
|
||||
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum uint32
|
||||
for i := 0; i < len(data)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
}
|
||||
|
||||
if len(data)%2 == 1 {
|
||||
sum += uint32(data[len(data)-1]) << 8
|
||||
}
|
||||
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||
}
|
||||
|
||||
m.dnatMutex.Lock()
|
||||
m.dnatMappings[originalAddr] = translatedAddr
|
||||
if len(m.dnatMappings) == 1 {
|
||||
m.dnatEnabled.Store(true)
|
||||
}
|
||||
m.dnatMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||
}
|
||||
|
||||
delete(m.dnatMappings, originalAddr)
|
||||
if len(m.dnatMappings) == 0 {
|
||||
m.dnatEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
translated, exists := m.dnatMappings[addr]
|
||||
m.dnatMutex.RUnlock()
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
}
|
||||
|
||||
m.dnatMutex.RLock()
|
||||
defer m.dnatMutex.RUnlock()
|
||||
|
||||
for original, translated := range m.dnatMappings {
|
||||
if translated == translatedAddr {
|
||||
return original, true
|
||||
}
|
||||
}
|
||||
|
||||
return translatedAddr, false
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
_, dstIP := m.extractIPs(d)
|
||||
if !dstIP.IsValid() || !dstIP.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP, _ := m.extractIPs(d)
|
||||
if !srcIP.IsValid() || !srcIP.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return fmt.Errorf("only IPv4 supported")
|
||||
}
|
||||
|
||||
oldDst := make([]byte, 4)
|
||||
copy(oldDst, packetData[16:20])
|
||||
newDst := newIP.AsSlice()
|
||||
|
||||
copy(packetData[16:20], newDst)
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst, newDst)
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst, newDst)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return fmt.Errorf("only IPv4 supported")
|
||||
}
|
||||
|
||||
oldSrc := make([]byte, 4)
|
||||
copy(oldSrc, packetData[12:16])
|
||||
newSrc := newIP.AsSlice()
|
||||
|
||||
copy(packetData[12:16], newSrc)
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc, newSrc)
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc, newSrc)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
checksumOffset := udpStart + 6
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
if oldChecksum == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
return
|
||||
}
|
||||
|
||||
icmpData := packetData[icmpStart:]
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||
checksum := icmpChecksum(icmpData)
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||
}
|
||||
if len(oldBytes)%2 == 1 {
|
||||
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
||||
}
|
||||
|
||||
for i := 0; i < len(newBytes)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
||||
}
|
||||
if len(newBytes)%2 == 1 {
|
||||
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
||||
}
|
||||
|
||||
for (sum >> 16) > 0 {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
@ -488,9 +488,9 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
return fmt.Errorf("set firewall: %w", err)
|
||||
}
|
||||
|
||||
if e.config.BlockLANAccess {
|
||||
|
@ -10,11 +10,10 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@ -553,41 +552,16 @@ func (w *Watcher) Stop() {
|
||||
w.currentChosenStatus = nil
|
||||
}
|
||||
|
||||
func HandlerFromRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsRouterInteval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.WGIface,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
useNewDNSRoute bool,
|
||||
) RouteHandler {
|
||||
switch handlerType(rt, useNewDNSRoute) {
|
||||
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||
switch handlerType(params.Route, params.UseNewDNSRoute) {
|
||||
case handlerTypeDnsInterceptor:
|
||||
return dnsinterceptor.New(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
statusRecorder,
|
||||
dnsServer,
|
||||
wgInterface,
|
||||
peerStore,
|
||||
)
|
||||
return dnsinterceptor.New(params)
|
||||
case handlerTypeDynamic:
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
dnsRouterInteval,
|
||||
statusRecorder,
|
||||
wgInterface,
|
||||
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||
)
|
||||
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||
return dynamic.NewRoute(params, dnsAddr)
|
||||
default:
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
return static.NewRoute(params)
|
||||
}
|
||||
}
|
||||
|
||||
|
28
client/internal/routemanager/common/params.go
Normal file
28
client/internal/routemanager/common/params.go
Normal file
@ -0,0 +1,28 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type HandlerParams struct {
|
||||
Route *route.Route
|
||||
RouteRefCounter *refcounter.RouteRefCounter
|
||||
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
DnsRouterInteval time.Duration
|
||||
StatusRecorder *peer.Status
|
||||
WgInterface iface.WGIface
|
||||
DnsServer dns.Server
|
||||
PeerStore *peerstore.Store
|
||||
UseNewDNSRoute bool
|
||||
Firewall manager.Manager
|
||||
FakeIPManager *fakeip.FakeIPManager
|
||||
}
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@ -12,11 +13,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@ -24,6 +28,11 @@ import (
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type internalDNATer interface {
|
||||
RemoveInternalDNATMapping(netip.Addr) error
|
||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||
}
|
||||
|
||||
type wgInterface interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
@ -40,26 +49,22 @@ type DnsInterceptor struct {
|
||||
interceptedDomains domainMap
|
||||
wgInterface wgInterface
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.FakeIPManager
|
||||
}
|
||||
|
||||
func New(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
statusRecorder *peer.Status,
|
||||
dnsServer nbdns.Server,
|
||||
wgInterface wgInterface,
|
||||
peerStore *peerstore.Store,
|
||||
) *DnsInterceptor {
|
||||
func New(params common.HandlerParams) *DnsInterceptor {
|
||||
return &DnsInterceptor{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dnsServer,
|
||||
wgInterface: wgInterface,
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
statusRecorder: params.StatusRecorder,
|
||||
dnsServer: params.DnsServer,
|
||||
wgInterface: params.WgInterface,
|
||||
peerStore: params.PeerStore,
|
||||
firewall: params.Firewall,
|
||||
fakeIPManager: params.FakeIPManager,
|
||||
interceptedDomains: make(domainMap),
|
||||
peerStore: peerStore,
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||
// Routes should use fake IPs
|
||||
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
|
||||
}
|
||||
|
||||
// AllowedIPs should use real IPs
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
}
|
||||
}
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
}
|
||||
|
||||
d.cleanupDNATMappings()
|
||||
|
||||
for _, domain := range d.route.Domains {
|
||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
|
||||
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
|
||||
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
|
||||
return realPrefix
|
||||
}
|
||||
|
||||
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
|
||||
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
|
||||
}
|
||||
|
||||
return realPrefix
|
||||
}
|
||||
|
||||
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
|
||||
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
|
||||
// AllowedIPs always use real IPs
|
||||
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
|
||||
}
|
||||
|
||||
if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
realPrefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
|
||||
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
|
||||
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
|
||||
routePrefix := d.transformRealToFakePrefix(realPrefix)
|
||||
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
|
||||
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
|
||||
}
|
||||
|
||||
// Add to AllowedIPs if we have a current peer (uses real IPs)
|
||||
if d.currentPeerKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
|
||||
}
|
||||
|
||||
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
|
||||
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
|
||||
if d.currentPeerKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowedIPs use real IPs
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
|
||||
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
// AllowedIPs use real IPs
|
||||
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
var merr *multierror.Error
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
// AllowedIPs use real IPs
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
@ -284,6 +353,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||
}
|
||||
}
|
||||
|
||||
@ -294,6 +365,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// logPrefixChanges handles the logging for prefix changes
|
||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
@ -302,70 +389,184 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
var merr *multierror.Error
|
||||
var dnatMappings map[netip.Addr]netip.Addr
|
||||
|
||||
// Handle DNAT mappings for new prefixes
|
||||
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
|
||||
dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
for _, prefix := range toAdd {
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||
dnatMappings[fakeIP] = realIP
|
||||
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||
} else {
|
||||
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add new prefixes
|
||||
for _, prefix := range toAdd {
|
||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if d.currentPeerKey == "" {
|
||||
continue
|
||||
}
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
resolvedDomain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
|
||||
d.addDNATMappings(dnatMappings)
|
||||
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
for _, prefix := range toRemove {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
||||
// Routes use fake IPs
|
||||
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
// AllowedIPs use real IPs
|
||||
if err := d.removeAllowedIP(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
|
||||
d.removeDNATMappingsForRealIPs(toRemove)
|
||||
}
|
||||
|
||||
// Update domain prefixes using resolved domain as key
|
||||
// Update domain prefixes using resolved domain as key - store real IPs
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
if d.route.KeepRoute {
|
||||
// replace stored prefixes with old + added
|
||||
// nolint:gocritic
|
||||
newPrefixes = append(oldPrefixes, toAdd...)
|
||||
}
|
||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||
|
||||
// Store real IPs for status (user-facing), not fake IPs
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// removeDNATMappingsForRealIPs removes DNAT mappings from the firewall for real IP prefixes
|
||||
func (d *DnsInterceptor) removeDNATMappingsForRealIPs(realPrefixes []netip.Prefix) {
|
||||
if len(realPrefixes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnatFirewall, ok := d.internalDnatFw()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, prefix := range realPrefixes {
|
||||
realIP := prefix.Addr()
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
} else {
|
||||
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalDnatFw checks if the firewall supports internal DNAT
|
||||
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||
if d.firewall == nil || runtime.GOOS != "android" {
|
||||
return nil, false
|
||||
}
|
||||
fw, ok := d.firewall.(internalDNATer)
|
||||
return fw, ok
|
||||
}
|
||||
|
||||
// addDNATMappings adds DNAT mappings to the firewall
|
||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnatFirewall, ok := d.internalDnatFw()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for fakeIP, realIP := range mappings {
|
||||
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||
} else {
|
||||
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// removeDNATMappings removes DNAT mappings from the firewall for removed prefixes
|
||||
func (d *DnsInterceptor) removeDNATMappings(prefixes []netip.Prefix) {
|
||||
if len(prefixes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnatFirewall, ok := d.internalDnatFw()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
fakeIP := prefix.Addr()
|
||||
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||
} else {
|
||||
log.Debugf("Removed DNAT mapping for: %s", fakeIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupDNATMappings removes all DNAT mappings for this interceptor
|
||||
func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
d.removeDNATMappingsForRealIPs(prefixes)
|
||||
}
|
||||
}
|
||||
|
||||
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||
if _, ok := d.internalDnatFw(); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Replace A and AAAA records with fake IPs
|
||||
for _, answer := range reply.Answer {
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
realIP, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.A = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
|
||||
case *dns.AAAA:
|
||||
realIP, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||
rr.AAAA = fakeIP.AsSlice()
|
||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
@ -52,24 +53,16 @@ type Route struct {
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
func NewRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.WGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
interval: params.DnsRouterInteval,
|
||||
statusRecorder: params.StatusRecorder,
|
||||
wgInterface: params.WgInterface,
|
||||
resolverAddr: resolverAddr,
|
||||
dynamicDomains: domainMap{},
|
||||
}
|
||||
}
|
||||
|
||||
|
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
@ -0,0 +1,93 @@
|
||||
package fakeip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FakeIPManager manages allocation of fake IPs from the 240.0.0.0/8 block
|
||||
type FakeIPManager struct {
|
||||
mu sync.Mutex
|
||||
nextIP netip.Addr // Next IP to allocate
|
||||
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||
baseIP netip.Addr // First usable IP: 240.0.0.1
|
||||
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
||||
}
|
||||
|
||||
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
||||
func NewManager() *FakeIPManager {
|
||||
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||
|
||||
return &FakeIPManager{
|
||||
nextIP: baseIP,
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: baseIP,
|
||||
maxIP: maxIP,
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateFakeIP allocates a fake IP for the given real IP
|
||||
// Returns the fake IP, or existing fake IP if already allocated
|
||||
func (f *FakeIPManager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||
if !realIP.Is4() {
|
||||
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if fakeIP, exists := f.allocated[realIP]; exists {
|
||||
return fakeIP, nil
|
||||
}
|
||||
|
||||
startIP := f.nextIP
|
||||
for {
|
||||
currentIP := f.nextIP
|
||||
|
||||
// Advance to next IP, wrapping at boundary
|
||||
if f.nextIP.Compare(f.maxIP) >= 0 {
|
||||
f.nextIP = f.baseIP
|
||||
} else {
|
||||
f.nextIP = f.nextIP.Next()
|
||||
}
|
||||
|
||||
// Check if current IP is available
|
||||
if _, inUse := f.fakeToReal[currentIP]; !inUse {
|
||||
f.allocated[realIP] = currentIP
|
||||
f.fakeToReal[currentIP] = realIP
|
||||
return currentIP, nil
|
||||
}
|
||||
|
||||
// Prevent infinite loop if all IPs exhausted
|
||||
if f.nextIP.Compare(startIP) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetFakeIP returns the fake IP for a real IP if it exists
|
||||
func (f *FakeIPManager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
fakeIP, exists := f.allocated[realIP]
|
||||
return fakeIP, exists
|
||||
}
|
||||
|
||||
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
||||
func (f *FakeIPManager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
realIP, exists := f.fakeToReal[fakeIP]
|
||||
return realIP, exists
|
||||
}
|
||||
|
||||
// GetFakeIPBlock returns the fake IP block used by this manager
|
||||
func (f *FakeIPManager) GetFakeIPBlock() netip.Prefix {
|
||||
return netip.MustParsePrefix("240.0.0.0/8")
|
||||
}
|
242
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
242
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
@ -0,0 +1,242 @@
|
||||
package fakeip
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
if manager.baseIP.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
||||
}
|
||||
|
||||
if manager.maxIP.String() != "240.255.255.254" {
|
||||
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
||||
}
|
||||
|
||||
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
||||
t.Errorf("Expected nextIP to start at baseIP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateFakeIP(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("8.8.8.8")
|
||||
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||
}
|
||||
|
||||
if !fakeIP.Is4() {
|
||||
t.Error("Fake IP should be IPv4")
|
||||
}
|
||||
|
||||
// Check it's in the correct range
|
||||
if fakeIP.As4()[0] != 240 {
|
||||
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||
}
|
||||
|
||||
// Should return same fake IP for same real IP
|
||||
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get existing fake IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP.Compare(fakeIP2) != 0 {
|
||||
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
||||
|
||||
_, err := manager.AllocateFakeIP(realIPv6)
|
||||
if err == nil {
|
||||
t.Error("Expected error for IPv6 address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFakeIP(t *testing.T) {
|
||||
manager := NewManager()
|
||||
realIP := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Should not exist initially
|
||||
_, exists := manager.GetFakeIP(realIP)
|
||||
if exists {
|
||||
t.Error("Fake IP should not exist before allocation")
|
||||
}
|
||||
|
||||
// Allocate and check
|
||||
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate: %v", err)
|
||||
}
|
||||
|
||||
fakeIP, exists := manager.GetFakeIP(realIP)
|
||||
if !exists {
|
||||
t.Error("Fake IP should exist after allocation")
|
||||
}
|
||||
|
||||
if fakeIP.Compare(expectedFakeIP) != 0 {
|
||||
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestMultipleAllocations(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
allocations := make(map[netip.Addr]netip.Addr)
|
||||
|
||||
// Allocate multiple IPs
|
||||
for i := 1; i <= 100; i++ {
|
||||
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
for _, existingFake := range allocations {
|
||||
if fakeIP.Compare(existingFake) == 0 {
|
||||
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
allocations[realIP] = fakeIP
|
||||
}
|
||||
|
||||
// Verify all allocations can be retrieved
|
||||
for realIP, expectedFake := range allocations {
|
||||
actualFake, exists := manager.GetFakeIP(realIP)
|
||||
if !exists {
|
||||
t.Errorf("Missing allocation for %s", realIP.String())
|
||||
}
|
||||
if actualFake.Compare(expectedFake) != 0 {
|
||||
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFakeIPBlock(t *testing.T) {
|
||||
manager := NewManager()
|
||||
block := manager.GetFakeIPBlock()
|
||||
|
||||
expected := "240.0.0.0/8"
|
||||
if block.String() != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, block.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
const numGoroutines = 50
|
||||
const allocationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||
|
||||
// Concurrent allocations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < allocationsPerGoroutine; j++ {
|
||||
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
|
||||
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
|
||||
return
|
||||
}
|
||||
results <- fakeIP
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Check for duplicates
|
||||
seen := make(map[netip.Addr]bool)
|
||||
count := 0
|
||||
for fakeIP := range results {
|
||||
if seen[fakeIP] {
|
||||
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
|
||||
}
|
||||
seen[fakeIP] = true
|
||||
count++
|
||||
}
|
||||
|
||||
if count != numGoroutines*allocationsPerGoroutine {
|
||||
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPExhaustion(t *testing.T) {
|
||||
// Create a manager with limited range for testing
|
||||
manager := &FakeIPManager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
||||
}
|
||||
|
||||
// Allocate all available IPs
|
||||
realIPs := []netip.Addr{
|
||||
netip.MustParseAddr("1.0.0.1"),
|
||||
netip.MustParseAddr("1.0.0.2"),
|
||||
netip.MustParseAddr("1.0.0.3"),
|
||||
}
|
||||
|
||||
for _, realIP := range realIPs {
|
||||
_, err := manager.AllocateFakeIP(realIP)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to allocate one more - should fail
|
||||
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||
if err == nil {
|
||||
t.Error("Expected exhaustion error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapAround(t *testing.T) {
|
||||
// Create manager starting near the end of range
|
||||
manager := &FakeIPManager{
|
||||
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
allocated: make(map[netip.Addr]netip.Addr),
|
||||
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||
}
|
||||
|
||||
// Allocate the last IP
|
||||
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP1.String() != "240.0.0.254" {
|
||||
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||
}
|
||||
|
||||
// Next allocation should wrap around to the beginning
|
||||
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||
}
|
||||
|
||||
if fakeIP2.String() != "240.0.0.1" {
|
||||
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@ -24,6 +25,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
@ -38,6 +41,10 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type internalDNATer interface {
|
||||
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||
}
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
@ -49,7 +56,7 @@ type Manager interface {
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
SetFirewall(firewall.Manager) error
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
@ -89,11 +96,13 @@ type DefaultManager struct {
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
dnsServer dns.Server
|
||||
firewall firewall.Manager
|
||||
peerStore *peerstore.Store
|
||||
useNewDNSRoute bool
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||
fakeIPManager *fakeip.FakeIPManager
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
@ -129,6 +138,8 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
dm.fakeIPManager = fakeip.NewManager()
|
||||
|
||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||
dm.notifier.SetInitialClientRoutes(cr)
|
||||
}
|
||||
@ -222,16 +233,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
if m.disableServerRoutes {
|
||||
// SetFirewall sets the firewall manager for the DefaultManager
|
||||
// Not thread-safe, should be called before starting the manager
|
||||
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||
m.firewall = firewall
|
||||
|
||||
if m.disableServerRoutes || firewall == nil {
|
||||
log.Info("server routes are disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if firewall == nil {
|
||||
return errors.New("firewall manager is not set")
|
||||
}
|
||||
|
||||
var err error
|
||||
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
if err != nil {
|
||||
@ -299,17 +310,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
}
|
||||
|
||||
for id, route := range toAdd {
|
||||
handler := client.HandlerFromRoute(
|
||||
route,
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsRouteInterval,
|
||||
m.statusRecorder,
|
||||
m.wgInterface,
|
||||
m.dnsServer,
|
||||
m.peerStore,
|
||||
m.useNewDNSRoute,
|
||||
)
|
||||
params := common.HandlerParams{
|
||||
Route: route,
|
||||
RouteRefCounter: m.routeRefCounter,
|
||||
AllowedIPsRefCounter: m.allowedIPsRefCounter,
|
||||
DnsRouterInteval: m.dnsRouteInterval,
|
||||
StatusRecorder: m.statusRecorder,
|
||||
WgInterface: m.wgInterface,
|
||||
DnsServer: m.dnsServer,
|
||||
PeerStore: m.peerStore,
|
||||
UseNewDNSRoute: m.useNewDNSRoute,
|
||||
Firewall: m.firewall,
|
||||
FakeIPManager: m.fakeIPManager,
|
||||
}
|
||||
handler := client.HandlerFromRoute(params)
|
||||
if err := handler.AddRoute(m.ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||
continue
|
||||
@ -517,9 +531,27 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
}
|
||||
|
||||
fakeIPBlock := m.fakeIPManager.GetFakeIPBlock()
|
||||
id := uuid.NewString()
|
||||
fakeIPRoute := &route.Route{
|
||||
ID: route.ID(id),
|
||||
Network: fakeIPBlock,
|
||||
NetID: route.NetID(id),
|
||||
Peer: m.pubKey,
|
||||
NetworkType: route.IPv4Network,
|
||||
}
|
||||
rs = append(rs, fakeIPRoute)
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
// supportsInternalDNAT checks if the firewall supports internal DNAT
|
||||
func (m *DefaultManager) supportsInternalDNAT(fw firewall.Manager) bool {
|
||||
_, ok := fw.(internalDNATer)
|
||||
return ok
|
||||
}
|
||||
|
||||
func isRouteSupported(route *route.Route) bool {
|
||||
if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||
return true
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
||||
|
||||
}
|
||||
|
||||
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
|
@ -32,10 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets := make([]string, 0)
|
||||
for _, r := range clientRoutes {
|
||||
// filter out domain routes
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
nets = append(nets, r.Network.String())
|
||||
}
|
||||
sort.Strings(nets)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@ -16,11 +17,11 @@ type Route struct {
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
}
|
||||
|
||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
||||
func NewRoute(params common.HandlerParams) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
route: params.Route,
|
||||
routeRefCounter: params.RouteRefCounter,
|
||||
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user