mirror of
https://github.com/netbirdio/netbird.git
synced 2025-07-11 11:55:48 +02:00
Implement dns routes for Android
This commit is contained in:
client
android
firewall
uspfilter
internal
@ -203,10 +203,6 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if routes[0].IsDynamic() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||||
|
@ -104,6 +104,11 @@ type Manager struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
|
||||||
blockRule firewall.Rule
|
blockRule firewall.Rule
|
||||||
|
|
||||||
|
// Internal 1:1 DNAT
|
||||||
|
dnatEnabled atomic.Bool
|
||||||
|
dnatMappings map[netip.Addr]netip.Addr
|
||||||
|
dnatMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@ -189,6 +194,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
@ -519,22 +525,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
|
||||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return nil, errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
|
||||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
@ -608,6 +598,14 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
translated := m.translateOutboundDNAT(packetData, d)
|
||||||
|
if translated {
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
m.logger.Error("Failed to re-decode packet after DNAT: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if !srcIP.IsValid() {
|
if !srcIP.IsValid() {
|
||||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
@ -618,7 +616,6 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// for netflow we keep track even if the firewall is stateless
|
|
||||||
m.trackOutbound(d, srcIP, dstIP, size)
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
@ -747,9 +744,17 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
// Step 1: Check connection tracking FIRST (with original addresses)
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
|
// Step 2: Apply reverse DNAT for established connections
|
||||||
|
translated := m.translateInboundReverse(packetData, d)
|
||||||
|
if translated {
|
||||||
|
// Re-decode after translation
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
309
client/firewall/uspfilter/nat.go
Normal file
309
client/firewall/uspfilter/nat.go
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ipv4Checksum(header []byte) uint16 {
|
||||||
|
if len(header) < 20 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var sum uint32
|
||||||
|
for i := 0; i < len(header)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(header)%2 == 1 {
|
||||||
|
sum += uint32(header[len(header)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
for (sum >> 16) > 0 {
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func icmpChecksum(data []byte) uint16 {
|
||||||
|
var sum uint32
|
||||||
|
for i := 0; i < len(data)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data)%2 == 1 {
|
||||||
|
sum += uint32(data[len(data)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
for (sum >> 16) > 0 {
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||||
|
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||||
|
return fmt.Errorf("invalid IP addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||||
|
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
m.dnatMappings[originalAddr] = translatedAddr
|
||||||
|
if len(m.dnatMappings) == 1 {
|
||||||
|
m.dnatEnabled.Store(true)
|
||||||
|
}
|
||||||
|
m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||||
|
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
defer m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||||
|
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.dnatMappings, originalAddr)
|
||||||
|
if len(m.dnatMappings) == 0 {
|
||||||
|
m.dnatEnabled.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDNATTranslation returns the translated address if a mapping exists
|
||||||
|
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return addr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
translated, exists := m.dnatMappings[addr]
|
||||||
|
m.dnatMutex.RUnlock()
|
||||||
|
return translated, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// findReverseDNATMapping finds original address for return traffic
|
||||||
|
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return translatedAddr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
defer m.dnatMutex.RUnlock()
|
||||||
|
|
||||||
|
for original, translated := range m.dnatMappings {
|
||||||
|
if translated == translatedAddr {
|
||||||
|
return original, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return translatedAddr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||||
|
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, dstIP := m.extractIPs(d)
|
||||||
|
if !dstIP.IsValid() || !dstIP.Is4() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||||
|
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP, _ := m.extractIPs(d)
|
||||||
|
if !srcIP.IsValid() || !srcIP.Is4() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketDestination replaces destination IP in the packet
|
||||||
|
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||||
|
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||||
|
return fmt.Errorf("only IPv4 supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDst := make([]byte, 4)
|
||||||
|
copy(oldDst, packetData[16:20])
|
||||||
|
newDst := newIP.AsSlice()
|
||||||
|
|
||||||
|
copy(packetData[16:20], newDst)
|
||||||
|
|
||||||
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
|
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||||
|
|
||||||
|
if len(d.decoded) > 1 {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst, newDst)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst, newDst)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketSource replaces the source IP address in the packet
|
||||||
|
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||||
|
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||||
|
return fmt.Errorf("only IPv4 supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldSrc := make([]byte, 4)
|
||||||
|
copy(oldSrc, packetData[12:16])
|
||||||
|
newSrc := newIP.AsSlice()
|
||||||
|
|
||||||
|
copy(packetData[12:16], newSrc)
|
||||||
|
|
||||||
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
|
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||||
|
|
||||||
|
if len(d.decoded) > 1 {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc, newSrc)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc, newSrc)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
tcpStart := ipHeaderLen
|
||||||
|
if len(packetData) < tcpStart+18 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := tcpStart + 16
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
udpStart := ipHeaderLen
|
||||||
|
if len(packetData) < udpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := udpStart + 6
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
|
||||||
|
if oldChecksum == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||||
|
icmpStart := ipHeaderLen
|
||||||
|
if len(packetData) < icmpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpData := packetData[icmpStart:]
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||||
|
checksum := icmpChecksum(icmpData)
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||||
|
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||||
|
sum := uint32(^oldChecksum)
|
||||||
|
|
||||||
|
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||||
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(oldBytes)%2 == 1 {
|
||||||
|
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(newBytes)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(newBytes)%2 == 1 {
|
||||||
|
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
for (sum >> 16) > 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
|
}
|
@ -488,9 +488,9 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) initFirewall() error {
|
func (e *Engine) initFirewall() error {
|
||||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("enable server router: %w", err)
|
return fmt.Errorf("set firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.config.BlockLANAccess {
|
if e.config.BlockLANAccess {
|
||||||
|
@ -10,11 +10,10 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -553,41 +552,16 @@ func (w *Watcher) Stop() {
|
|||||||
w.currentChosenStatus = nil
|
w.currentChosenStatus = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandlerFromRoute(
|
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||||
rt *route.Route,
|
switch handlerType(params.Route, params.UseNewDNSRoute) {
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
dnsRouterInteval time.Duration,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
wgInterface iface.WGIface,
|
|
||||||
dnsServer nbdns.Server,
|
|
||||||
peerStore *peerstore.Store,
|
|
||||||
useNewDNSRoute bool,
|
|
||||||
) RouteHandler {
|
|
||||||
switch handlerType(rt, useNewDNSRoute) {
|
|
||||||
case handlerTypeDnsInterceptor:
|
case handlerTypeDnsInterceptor:
|
||||||
return dnsinterceptor.New(
|
return dnsinterceptor.New(params)
|
||||||
rt,
|
|
||||||
routeRefCounter,
|
|
||||||
allowedIPsRefCounter,
|
|
||||||
statusRecorder,
|
|
||||||
dnsServer,
|
|
||||||
wgInterface,
|
|
||||||
peerStore,
|
|
||||||
)
|
|
||||||
case handlerTypeDynamic:
|
case handlerTypeDynamic:
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||||
return dynamic.NewRoute(
|
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||||
rt,
|
return dynamic.NewRoute(params, dnsAddr)
|
||||||
routeRefCounter,
|
|
||||||
allowedIPsRefCounter,
|
|
||||||
dnsRouterInteval,
|
|
||||||
statusRecorder,
|
|
||||||
wgInterface,
|
|
||||||
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
|
||||||
)
|
|
||||||
default:
|
default:
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
return static.NewRoute(params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
28
client/internal/routemanager/common/params.go
Normal file
28
client/internal/routemanager/common/params.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandlerParams struct {
|
||||||
|
Route *route.Route
|
||||||
|
RouteRefCounter *refcounter.RouteRefCounter
|
||||||
|
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
|
DnsRouterInteval time.Duration
|
||||||
|
StatusRecorder *peer.Status
|
||||||
|
WgInterface iface.WGIface
|
||||||
|
DnsServer dns.Server
|
||||||
|
PeerStore *peerstore.Store
|
||||||
|
UseNewDNSRoute bool
|
||||||
|
Firewall manager.Manager
|
||||||
|
FakeIPManager *fakeip.FakeIPManager
|
||||||
|
}
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -12,11 +13,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -24,6 +28,11 @@ import (
|
|||||||
|
|
||||||
type domainMap map[domain.Domain][]netip.Prefix
|
type domainMap map[domain.Domain][]netip.Prefix
|
||||||
|
|
||||||
|
type internalDNATer interface {
|
||||||
|
RemoveInternalDNATMapping(netip.Addr) error
|
||||||
|
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||||
|
}
|
||||||
|
|
||||||
type wgInterface interface {
|
type wgInterface interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
@ -40,26 +49,22 @@ type DnsInterceptor struct {
|
|||||||
interceptedDomains domainMap
|
interceptedDomains domainMap
|
||||||
wgInterface wgInterface
|
wgInterface wgInterface
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
|
firewall firewall.Manager
|
||||||
|
fakeIPManager *fakeip.FakeIPManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
func New(params common.HandlerParams) *DnsInterceptor {
|
||||||
rt *route.Route,
|
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
dnsServer nbdns.Server,
|
|
||||||
wgInterface wgInterface,
|
|
||||||
peerStore *peerstore.Store,
|
|
||||||
) *DnsInterceptor {
|
|
||||||
return &DnsInterceptor{
|
return &DnsInterceptor{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: params.StatusRecorder,
|
||||||
dnsServer: dnsServer,
|
dnsServer: params.DnsServer,
|
||||||
wgInterface: wgInterface,
|
wgInterface: params.WgInterface,
|
||||||
|
peerStore: params.PeerStore,
|
||||||
|
firewall: params.Firewall,
|
||||||
|
fakeIPManager: params.FakeIPManager,
|
||||||
interceptedDomains: make(domainMap),
|
interceptedDomains: make(domainMap),
|
||||||
peerStore: peerStore,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for domain, prefixes := range d.interceptedDomains {
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
// Routes should use fake IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||||
|
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedIPs should use real IPs
|
||||||
if d.currentPeerKey != "" {
|
if d.currentPeerKey != "" {
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.cleanupDNATMappings()
|
||||||
|
|
||||||
for _, domain := range d.route.Domains {
|
for _, domain := range d.route.Domains {
|
||||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||||
}
|
}
|
||||||
@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
|
||||||
|
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
|
||||||
|
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
|
||||||
|
return realPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
|
||||||
|
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
|
||||||
|
}
|
||||||
|
|
||||||
|
return realPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
|
||||||
|
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
|
||||||
|
// AllowedIPs always use real IPs
|
||||||
|
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ref.Count > 1 && ref.Out != peerKey {
|
||||||
|
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||||
|
realPrefix.Addr(),
|
||||||
|
domain.SafeString(),
|
||||||
|
ref.Out,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
|
||||||
|
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
|
||||||
|
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
|
||||||
|
routePrefix := d.transformRealToFakePrefix(realPrefix)
|
||||||
|
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
|
||||||
|
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to AllowedIPs if we have a current peer (uses real IPs)
|
||||||
|
if d.currentPeerKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
|
||||||
|
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
|
||||||
|
if d.currentPeerKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowedIPs use real IPs
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
|
||||||
|
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for domain, prefixes := range d.interceptedDomains {
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
// AllowedIPs use real IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
|
||||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
merr = multierror.Append(merr, err)
|
||||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
|
||||||
prefix.Addr(),
|
|
||||||
domain.SafeString(),
|
|
||||||
ref.Out,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, prefixes := range d.interceptedDomains {
|
for _, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
|
// AllowedIPs use real IPs
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
}
|
}
|
||||||
@ -284,6 +353,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||||
log.Errorf("failed to update domain prefixes: %v", err)
|
log.Errorf("failed to update domain prefixes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -294,6 +365,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// logPrefixChanges handles the logging for prefix changes
|
||||||
|
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||||
|
if len(toAdd) > 0 {
|
||||||
|
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
|
resolvedDomain.SafeString(),
|
||||||
|
originalDomain.SafeString(),
|
||||||
|
toAdd)
|
||||||
|
}
|
||||||
|
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||||
|
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
|
resolvedDomain.SafeString(),
|
||||||
|
originalDomain.SafeString(),
|
||||||
|
toRemove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
@ -302,70 +389,184 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
var dnatMappings map[netip.Addr]netip.Addr
|
||||||
|
|
||||||
|
// Handle DNAT mappings for new prefixes
|
||||||
|
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
|
||||||
|
dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||||
|
for _, prefix := range toAdd {
|
||||||
|
realIP := prefix.Addr()
|
||||||
|
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||||
|
dnatMappings[fakeIP] = realIP
|
||||||
|
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||||
|
} else {
|
||||||
|
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add new prefixes
|
// Add new prefixes
|
||||||
for _, prefix := range toAdd {
|
for _, prefix := range toAdd {
|
||||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, err)
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.currentPeerKey == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
|
||||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
|
||||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
|
||||||
prefix.Addr(),
|
|
||||||
resolvedDomain.SafeString(),
|
|
||||||
ref.Out,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.addDNATMappings(dnatMappings)
|
||||||
|
|
||||||
if !d.route.KeepRoute {
|
if !d.route.KeepRoute {
|
||||||
// Remove old prefixes
|
// Remove old prefixes
|
||||||
for _, prefix := range toRemove {
|
for _, prefix := range toRemove {
|
||||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
// Routes use fake IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||||
|
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
|
||||||
}
|
}
|
||||||
if d.currentPeerKey != "" {
|
// AllowedIPs use real IPs
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if err := d.removeAllowedIP(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.removeDNATMappingsForRealIPs(toRemove)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update domain prefixes using resolved domain as key
|
// Update domain prefixes using resolved domain as key - store real IPs
|
||||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||||
if d.route.KeepRoute {
|
if d.route.KeepRoute {
|
||||||
// replace stored prefixes with old + added
|
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
newPrefixes = append(oldPrefixes, toAdd...)
|
newPrefixes = append(oldPrefixes, toAdd...)
|
||||||
}
|
}
|
||||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||||
|
|
||||||
|
// Store real IPs for status (user-facing), not fake IPs
|
||||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||||
|
|
||||||
if len(toAdd) > 0 {
|
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
|
||||||
resolvedDomain.SafeString(),
|
|
||||||
originalDomain.SafeString(),
|
|
||||||
toAdd)
|
|
||||||
}
|
|
||||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
|
||||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
|
||||||
resolvedDomain.SafeString(),
|
|
||||||
originalDomain.SafeString(),
|
|
||||||
toRemove)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeDNATMappingsForRealIPs removes DNAT mappings from the firewall for real IP prefixes
|
||||||
|
func (d *DnsInterceptor) removeDNATMappingsForRealIPs(realPrefixes []netip.Prefix) {
|
||||||
|
if len(realPrefixes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatFirewall, ok := d.internalDnatFw()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range realPrefixes {
|
||||||
|
realIP := prefix.Addr()
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||||
|
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// internalDnatFw checks if the firewall supports internal DNAT
|
||||||
|
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||||
|
if d.firewall == nil || runtime.GOOS != "android" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
fw, ok := d.firewall.(internalDNATer)
|
||||||
|
return fw, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// addDNATMappings adds DNAT mappings to the firewall
|
||||||
|
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||||
|
if len(mappings) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatFirewall, ok := d.internalDnatFw()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for fakeIP, realIP := range mappings {
|
||||||
|
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||||
|
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeDNATMappings removes DNAT mappings from the firewall for removed prefixes
|
||||||
|
func (d *DnsInterceptor) removeDNATMappings(prefixes []netip.Prefix) {
|
||||||
|
if len(prefixes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatFirewall, ok := d.internalDnatFw()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
fakeIP := prefix.Addr()
|
||||||
|
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||||
|
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Removed DNAT mapping for: %s", fakeIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupDNATMappings removes all DNAT mappings for this interceptor
|
||||||
|
func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||||
|
if _, ok := d.internalDnatFw(); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefixes := range d.interceptedDomains {
|
||||||
|
d.removeDNATMappingsForRealIPs(prefixes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||||
|
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||||
|
if _, ok := d.internalDnatFw(); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace A and AAAA records with fake IPs
|
||||||
|
for _, answer := range reply.Answer {
|
||||||
|
switch rr := answer.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
realIP, ok := netip.AddrFromSlice(rr.A)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
rr.A = fakeIP.AsSlice()
|
||||||
|
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
case *dns.AAAA:
|
||||||
|
realIP, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
rr.AAAA = fakeIP.AsSlice()
|
||||||
|
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||||
prefixSet := make(map[netip.Prefix]bool)
|
prefixSet := make(map[netip.Prefix]bool)
|
||||||
for _, prefix := range oldPrefixes {
|
for _, prefix := range oldPrefixes {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
@ -52,24 +53,16 @@ type Route struct {
|
|||||||
resolverAddr string
|
resolverAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(
|
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
||||||
rt *route.Route,
|
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
interval time.Duration,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
wgInterface iface.WGIface,
|
|
||||||
resolverAddr string,
|
|
||||||
) *Route {
|
|
||||||
return &Route{
|
return &Route{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
interval: interval,
|
interval: params.DnsRouterInteval,
|
||||||
dynamicDomains: domainMap{},
|
statusRecorder: params.StatusRecorder,
|
||||||
statusRecorder: statusRecorder,
|
wgInterface: params.WgInterface,
|
||||||
wgInterface: wgInterface,
|
|
||||||
resolverAddr: resolverAddr,
|
resolverAddr: resolverAddr,
|
||||||
|
dynamicDomains: domainMap{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package fakeip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FakeIPManager manages allocation of fake IPs from the 240.0.0.0/8 block
|
||||||
|
type FakeIPManager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
nextIP netip.Addr // Next IP to allocate
|
||||||
|
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||||
|
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||||
|
baseIP netip.Addr // First usable IP: 240.0.0.1
|
||||||
|
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
||||||
|
func NewManager() *FakeIPManager {
|
||||||
|
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||||
|
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||||
|
|
||||||
|
return &FakeIPManager{
|
||||||
|
nextIP: baseIP,
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: baseIP,
|
||||||
|
maxIP: maxIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllocateFakeIP allocates a fake IP for the given real IP
|
||||||
|
// Returns the fake IP, or existing fake IP if already allocated
|
||||||
|
func (f *FakeIPManager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||||
|
if !realIP.Is4() {
|
||||||
|
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
if fakeIP, exists := f.allocated[realIP]; exists {
|
||||||
|
return fakeIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
startIP := f.nextIP
|
||||||
|
for {
|
||||||
|
currentIP := f.nextIP
|
||||||
|
|
||||||
|
// Advance to next IP, wrapping at boundary
|
||||||
|
if f.nextIP.Compare(f.maxIP) >= 0 {
|
||||||
|
f.nextIP = f.baseIP
|
||||||
|
} else {
|
||||||
|
f.nextIP = f.nextIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if current IP is available
|
||||||
|
if _, inUse := f.fakeToReal[currentIP]; !inUse {
|
||||||
|
f.allocated[realIP] = currentIP
|
||||||
|
f.fakeToReal[currentIP] = realIP
|
||||||
|
return currentIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent infinite loop if all IPs exhausted
|
||||||
|
if f.nextIP.Compare(startIP) == 0 {
|
||||||
|
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIP returns the fake IP for a real IP if it exists
|
||||||
|
func (f *FakeIPManager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
fakeIP, exists := f.allocated[realIP]
|
||||||
|
return fakeIP, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
||||||
|
func (f *FakeIPManager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
realIP, exists := f.fakeToReal[fakeIP]
|
||||||
|
return realIP, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIPBlock returns the fake IP block used by this manager
|
||||||
|
func (f *FakeIPManager) GetFakeIPBlock() netip.Prefix {
|
||||||
|
return netip.MustParsePrefix("240.0.0.0/8")
|
||||||
|
}
|
242
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
242
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
package fakeip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManager(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
if manager.baseIP.String() != "240.0.0.1" {
|
||||||
|
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.maxIP.String() != "240.255.255.254" {
|
||||||
|
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
||||||
|
t.Errorf("Expected nextIP to start at baseIP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllocateFakeIP(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIP := netip.MustParseAddr("8.8.8.8")
|
||||||
|
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fakeIP.Is4() {
|
||||||
|
t.Error("Fake IP should be IPv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check it's in the correct range
|
||||||
|
if fakeIP.As4()[0] != 240 {
|
||||||
|
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return same fake IP for same real IP
|
||||||
|
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get existing fake IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP.Compare(fakeIP2) != 0 {
|
||||||
|
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
||||||
|
|
||||||
|
_, err := manager.AllocateFakeIP(realIPv6)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for IPv6 address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFakeIP(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
|
// Should not exist initially
|
||||||
|
_, exists := manager.GetFakeIP(realIP)
|
||||||
|
if exists {
|
||||||
|
t.Error("Fake IP should not exist before allocation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate and check
|
||||||
|
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeIP, exists := manager.GetFakeIP(realIP)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Fake IP should exist after allocation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP.Compare(expectedFakeIP) != 0 {
|
||||||
|
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
func TestMultipleAllocations(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
allocations := make(map[netip.Addr]netip.Addr)
|
||||||
|
|
||||||
|
// Allocate multiple IPs
|
||||||
|
for i := 1; i <= 100; i++ {
|
||||||
|
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for duplicates
|
||||||
|
for _, existingFake := range allocations {
|
||||||
|
if fakeIP.Compare(existingFake) == 0 {
|
||||||
|
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allocations[realIP] = fakeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all allocations can be retrieved
|
||||||
|
for realIP, expectedFake := range allocations {
|
||||||
|
actualFake, exists := manager.GetFakeIP(realIP)
|
||||||
|
if !exists {
|
||||||
|
t.Errorf("Missing allocation for %s", realIP.String())
|
||||||
|
}
|
||||||
|
if actualFake.Compare(expectedFake) != 0 {
|
||||||
|
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFakeIPBlock(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
block := manager.GetFakeIPBlock()
|
||||||
|
|
||||||
|
expected := "240.0.0.0/8"
|
||||||
|
if block.String() != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, block.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
const numGoroutines = 50
|
||||||
|
const allocationsPerGoroutine = 10
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||||
|
|
||||||
|
// Concurrent allocations
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(goroutineID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < allocationsPerGoroutine; j++ {
|
||||||
|
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
results <- fakeIP
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
// Check for duplicates
|
||||||
|
seen := make(map[netip.Addr]bool)
|
||||||
|
count := 0
|
||||||
|
for fakeIP := range results {
|
||||||
|
if seen[fakeIP] {
|
||||||
|
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
seen[fakeIP] = true
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != numGoroutines*allocationsPerGoroutine {
|
||||||
|
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPExhaustion(t *testing.T) {
|
||||||
|
// Create a manager with limited range for testing
|
||||||
|
manager := &FakeIPManager{
|
||||||
|
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate all available IPs
|
||||||
|
realIPs := []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.0.0.1"),
|
||||||
|
netip.MustParseAddr("1.0.0.2"),
|
||||||
|
netip.MustParseAddr("1.0.0.3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, realIP := range realIPs {
|
||||||
|
_, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to allocate one more - should fail
|
||||||
|
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected exhaustion error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapAround(t *testing.T) {
|
||||||
|
// Create manager starting near the end of range
|
||||||
|
manager := &FakeIPManager{
|
||||||
|
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the last IP
|
||||||
|
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP1.String() != "240.0.0.254" {
|
||||||
|
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next allocation should wrap around to the beginning
|
||||||
|
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP2.String() != "240.0.0.1" {
|
||||||
|
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||||
|
}
|
||||||
|
}
|
@ -11,6 +11,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@ -24,6 +25,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
@ -38,6 +41,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type internalDNATer interface {
|
||||||
|
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||||
|
}
|
||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||||
@ -49,7 +56,7 @@ type Manager interface {
|
|||||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
EnableServerRouter(firewall firewall.Manager) error
|
SetFirewall(firewall.Manager) error
|
||||||
Stop(stateManager *statemanager.Manager)
|
Stop(stateManager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,11 +96,13 @@ type DefaultManager struct {
|
|||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
clientRoutes route.HAMap
|
clientRoutes route.HAMap
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
firewall firewall.Manager
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
useNewDNSRoute bool
|
useNewDNSRoute bool
|
||||||
disableClientRoutes bool
|
disableClientRoutes bool
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||||
|
fakeIPManager *fakeip.FakeIPManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(config ManagerConfig) *DefaultManager {
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
@ -129,6 +138,8 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
|
dm.fakeIPManager = fakeip.NewManager()
|
||||||
|
|
||||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
dm.notifier.SetInitialClientRoutes(cr)
|
||||||
}
|
}
|
||||||
@ -222,16 +233,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
|||||||
return routeselector.NewRouteSelector()
|
return routeselector.NewRouteSelector()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
// SetFirewall sets the firewall manager for the DefaultManager
|
||||||
if m.disableServerRoutes {
|
// Not thread-safe, should be called before starting the manager
|
||||||
|
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||||
|
m.firewall = firewall
|
||||||
|
|
||||||
|
if m.disableServerRoutes || firewall == nil {
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if firewall == nil {
|
|
||||||
return errors.New("firewall manager is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -299,17 +310,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for id, route := range toAdd {
|
for id, route := range toAdd {
|
||||||
handler := client.HandlerFromRoute(
|
params := common.HandlerParams{
|
||||||
route,
|
Route: route,
|
||||||
m.routeRefCounter,
|
RouteRefCounter: m.routeRefCounter,
|
||||||
m.allowedIPsRefCounter,
|
AllowedIPsRefCounter: m.allowedIPsRefCounter,
|
||||||
m.dnsRouteInterval,
|
DnsRouterInteval: m.dnsRouteInterval,
|
||||||
m.statusRecorder,
|
StatusRecorder: m.statusRecorder,
|
||||||
m.wgInterface,
|
WgInterface: m.wgInterface,
|
||||||
m.dnsServer,
|
DnsServer: m.dnsServer,
|
||||||
m.peerStore,
|
PeerStore: m.peerStore,
|
||||||
m.useNewDNSRoute,
|
UseNewDNSRoute: m.useNewDNSRoute,
|
||||||
)
|
Firewall: m.firewall,
|
||||||
|
FakeIPManager: m.fakeIPManager,
|
||||||
|
}
|
||||||
|
handler := client.HandlerFromRoute(params)
|
||||||
if err := handler.AddRoute(m.ctx); err != nil {
|
if err := handler.AddRoute(m.ctx); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||||
continue
|
continue
|
||||||
@ -517,9 +531,27 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
|||||||
for _, routes := range crMap {
|
for _, routes := range crMap {
|
||||||
rs = append(rs, routes...)
|
rs = append(rs, routes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fakeIPBlock := m.fakeIPManager.GetFakeIPBlock()
|
||||||
|
id := uuid.NewString()
|
||||||
|
fakeIPRoute := &route.Route{
|
||||||
|
ID: route.ID(id),
|
||||||
|
Network: fakeIPBlock,
|
||||||
|
NetID: route.NetID(id),
|
||||||
|
Peer: m.pubKey,
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
}
|
||||||
|
rs = append(rs, fakeIPRoute)
|
||||||
|
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// supportsInternalDNAT checks if the firewall supports internal DNAT
|
||||||
|
func (m *DefaultManager) supportsInternalDNAT(fw firewall.Manager) bool {
|
||||||
|
_, ok := fw.(internalDNATer)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func isRouteSupported(route *route.Route) bool {
|
func isRouteSupported(route *route.Route) bool {
|
||||||
if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||||
return true
|
return true
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||||
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||||
TriggerSelectionFunc func(haMap route.HAMap)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
GetClientRoutesFunc func() route.HAMap
|
GetClientRoutesFunc func() route.HAMap
|
||||||
@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,10 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|||||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||||
nets := make([]string, 0)
|
nets := make([]string, 0)
|
||||||
for _, r := range clientRoutes {
|
for _, r := range clientRoutes {
|
||||||
// filter out domain routes
|
|
||||||
if r.IsDynamic() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nets = append(nets, r.Network.String())
|
nets = append(nets, r.Network.String())
|
||||||
}
|
}
|
||||||
sort.Strings(nets)
|
sort.Strings(nets)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@ -16,11 +17,11 @@ type Route struct {
|
|||||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
func NewRoute(params common.HandlerParams) *Route {
|
||||||
return &Route{
|
return &Route{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user