diff --git a/client/android/client.go b/client/android/client.go index 3b8a5bd0f..a17439696 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -203,8 +203,10 @@ func (c *Client) Networks() *NetworkArray { continue } - if routes[0].IsDynamic() { - continue + r := routes[0] + netStr := r.Network.String() + if r.IsDynamic() { + netStr = r.Domains.SafeString() } peer, err := c.recorder.GetPeer(routes[0].Peer) @@ -214,7 +216,7 @@ func (c *Client) Networks() *NetworkArray { } network := Network{ Name: string(id), - Network: routes[0].Network.String(), + Network: netStr, Peer: peer.FQDN, Status: peer.ConnStatus.String(), } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/filter.go similarity index 96% rename from client/firewall/uspfilter/uspfilter.go rename to client/firewall/uspfilter/filter.go index dcff92c61..7120d7d64 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/filter.go @@ -104,6 +104,12 @@ type Manager struct { flowLogger nftypes.FlowLogger blockRule firewall.Rule + + // Internal 1:1 DNAT + dnatEnabled atomic.Bool + dnatMappings map[netip.Addr]netip.Addr + dnatMutex sync.RWMutex + dnatBiMap *biDNATMap } // decoder for packages @@ -189,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe flowLogger: flowLogger, netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, + dnatMappings: make(map[netip.Addr]netip.Addr), } m.routingEnabled.Store(false) @@ -519,22 +526,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } -// AddDNATRule adds a DNAT rule -func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { - if m.nativeFirewall == nil { - return nil, errNatNotSupported - } - return m.nativeFirewall.AddDNATRule(rule) -} - -// DeleteDNATRule deletes a DNAT rule -func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { - if m.nativeFirewall == nil { - return errNatNotSupported - } - return m.nativeFirewall.DeleteDNATRule(rule) -} - // UpdateSet updates the rule destinations associated with the given set // by merging the existing prefixes with the new ones, then deduplicating. func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { @@ -581,14 +572,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } -// DropOutgoing filter outgoing packets -func (m *Manager) DropOutgoing(packetData []byte, size int) bool { - return m.processOutgoingHooks(packetData, size) +// FilterOutBound filters outgoing packets +func (m *Manager) FilterOutbound(packetData []byte, size int) bool { + return m.filterOutbound(packetData, size) } -// DropIncoming filter incoming packets -func (m *Manager) DropIncoming(packetData []byte, size int) bool { - return m.dropFilter(packetData, size) +// FilterInbound filters incoming packets +func (m *Manager) FilterInbound(packetData []byte, size int) bool { + return m.filterInbound(packetData, size) } // UpdateLocalIPs updates the list of local IPs @@ -596,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error { return m.localipmanager.UpdateLocalIPs(m.wgIface) } -func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool { +func (m *Manager) filterOutbound(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) @@ -618,8 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool { return true } - // for netflow we keep track even if the firewall is stateless m.trackOutbound(d, srcIP, dstIP, size) + m.translateOutboundDNAT(packetData, d) return false } @@ -723,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte return false } -// dropFilter implements filtering logic for incoming packets. +// filterInbound implements filtering logic for incoming packets. // If it returns true, the packet should be dropped. -func (m *Manager) dropFilter(packetData []byte, size int) bool { +func (m *Manager) filterInbound(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) @@ -747,8 +738,15 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool { return false } - // For all inbound traffic, first check if it matches a tracked connection. - // This must happen before any other filtering because the packets are statefully tracked. + if translated := m.translateInboundReverse(packetData, d); translated { + // Re-decode after translation to get original addresses + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err) + return true + } + srcIP, dstIP = m.extractIPs(d) + } + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { return false } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go similarity index 94% rename from client/firewall/uspfilter/uspfilter_bench_test.go rename to client/firewall/uspfilter/filter_bench_test.go index c03e60640..0cffcc1a7 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) { // For stateful scenarios, establish the connection if sc.stateful { - manager.processOutgoingHooks(outbound, 0) + manager.filterOutbound(outbound, 0) } // Measure inbound packet processing b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, 0) + manager.filterInbound(inbound, 0) } }) } @@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) { for i := 0; i < count; i++ { outbound := generatePacket(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, layers.IPProtocolTCP) - manager.processOutgoingHooks(outbound, 0) + manager.filterOutbound(outbound, 0) } // Test packet @@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) { testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) // First establish our test connection - manager.processOutgoingHooks(testOut, 0) + manager.filterOutbound(testOut, 0) b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(testIn, 0) + manager.filterInbound(testIn, 0) } }) } @@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) if sc.established { - manager.processOutgoingHooks(outbound, 0) + manager.filterOutbound(outbound, 0) } b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, 0) + manager.filterInbound(inbound, 0) } }) } @@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { // For stateful cases and established connections if !strings.Contains(sc.name, "allow_non_wg") || (strings.Contains(sc.state, "established") || sc.state == "post_handshake") { - manager.processOutgoingHooks(outbound, 0) + manager.filterOutbound(outbound, 0) // For TCP post-handshake, simulate full handshake if sc.state == "post_handshake" { // SYN syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) - manager.processOutgoingHooks(syn, 0) + manager.filterOutbound(syn, 0) // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, 0) + manager.filterInbound(synack, 0) // ACK ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) - manager.processOutgoingHooks(ack, 0) + manager.filterOutbound(ack, 0) } } b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, 0) + manager.filterInbound(inbound, 0) } }) } @@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) { // Initial SYN syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPSyn)) - manager.processOutgoingHooks(syn, 0) + manager.filterOutbound(syn, 0) // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, 0) + manager.filterInbound(synack, 0) // ACK ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPAck)) - manager.processOutgoingHooks(ack, 0) + manager.filterOutbound(ack, 0) } // Prepare test packets simulating bidirectional traffic @@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) { // Simulate bidirectional traffic // First outbound data - manager.processOutgoingHooks(outPackets[connIdx], 0) + manager.filterOutbound(outPackets[connIdx], 0) // Then inbound response - this is what we're actually measuring - manager.dropFilter(inPackets[connIdx], 0) + manager.filterInbound(inPackets[connIdx], 0) } }) } @@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) { p := patterns[connIdx] // Connection establishment - manager.processOutgoingHooks(p.syn, 0) - manager.dropFilter(p.synAck, 0) - manager.processOutgoingHooks(p.ack, 0) + manager.filterOutbound(p.syn, 0) + manager.filterInbound(p.synAck, 0) + manager.filterOutbound(p.ack, 0) // Data transfer - manager.processOutgoingHooks(p.request, 0) - manager.dropFilter(p.response, 0) + manager.filterOutbound(p.request, 0) + manager.filterInbound(p.response, 0) // Connection teardown - manager.processOutgoingHooks(p.finClient, 0) - manager.dropFilter(p.ackServer, 0) - manager.dropFilter(p.finServer, 0) - manager.processOutgoingHooks(p.ackClient, 0) + manager.filterOutbound(p.finClient, 0) + manager.filterInbound(p.ackServer, 0) + manager.filterInbound(p.finServer, 0) + manager.filterOutbound(p.ackClient, 0) } }) } @@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { for i := 0; i < sc.connCount; i++ { syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPSyn)) - manager.processOutgoingHooks(syn, 0) + manager.filterOutbound(syn, 0) synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, 0) + manager.filterInbound(synack, 0) ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPAck)) - manager.processOutgoingHooks(ack, 0) + manager.filterOutbound(ack, 0) } // Pre-generate test packets @@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { counter++ // Simulate bidirectional traffic - manager.processOutgoingHooks(outPackets[connIdx], 0) - manager.dropFilter(inPackets[connIdx], 0) + manager.filterOutbound(outPackets[connIdx], 0) + manager.filterInbound(inPackets[connIdx], 0) } }) }) @@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { p := patterns[connIdx] // Full connection lifecycle - manager.processOutgoingHooks(p.syn, 0) - manager.dropFilter(p.synAck, 0) - manager.processOutgoingHooks(p.ack, 0) + manager.filterOutbound(p.syn, 0) + manager.filterInbound(p.synAck, 0) + manager.filterOutbound(p.ack, 0) - manager.processOutgoingHooks(p.request, 0) - manager.dropFilter(p.response, 0) + manager.filterOutbound(p.request, 0) + manager.filterInbound(p.response, 0) - manager.processOutgoingHooks(p.finClient, 0) - manager.dropFilter(p.ackServer, 0) - manager.dropFilter(p.finServer, 0) - manager.processOutgoingHooks(p.ackClient, 0) + manager.filterOutbound(p.finClient, 0) + manager.filterInbound(p.ackServer, 0) + manager.filterInbound(p.finServer, 0) + manager.filterOutbound(p.ackClient, 0) } }) }) diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go similarity index 99% rename from client/firewall/uspfilter/uspfilter_filter_test.go rename to client/firewall/uspfilter/filter_filter_test.go index 318f86a87..b630c9e66 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) { packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) - isDropped := manager.DropIncoming(packet, 0) + isDropped := manager.FilterInbound(packet, 0) require.True(t, isDropped, "Packet should be dropped when no rules exist") }) @@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) { }) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) - isDropped := manager.DropIncoming(packet, 0) + isDropped := manager.FilterInbound(packet, 0) require.Equal(t, tc.shouldBeBlocked, isDropped) }) } @@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) { srcIP := netip.MustParseAddr(tc.srcIP) dstIP := netip.MustParseAddr(tc.dstIP) - // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed + // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed // to the forwarder _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) require.Equal(t, tc.shouldPass, isAllowed) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/filter_test.go similarity index 98% rename from client/firewall/uspfilter/uspfilter_test.go rename to client/firewall/uspfilter/filter_test.go index 88de1ddcd..5b5cd5a53 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), 0) { + if m.filterInbound(buf.Bytes(), 0) { t.Errorf("expected packet to be accepted") return } @@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) { require.NoError(t, err) // Test hook gets called - result := manager.processOutgoingHooks(buf.Bytes(), 0) + result := manager.filterOutbound(buf.Bytes(), 0) require.True(t, result) require.True(t, hookCalled) @@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) { err = gopacket.SerializeLayers(buf, opts, ipv4) require.NoError(t, err) - result = manager.processOutgoingHooks(buf.Bytes(), 0) + result = manager.filterOutbound(buf.Bytes(), 0) require.False(t, result) } @@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.NoError(t, err) // Process outbound packet and verify connection tracking - drop := manager.DropOutgoing(outboundBuf.Bytes(), 0) + drop := manager.FilterOutbound(outboundBuf.Bytes(), 0) require.False(t, drop, "Initial outbound packet should not be dropped") // Verify connection was tracked @@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { for _, cp := range checkPoints { time.Sleep(cp.sleep) - drop = manager.dropFilter(inboundBuf.Bytes(), 0) + drop = manager.filterInbound(inboundBuf.Bytes(), 0) require.Equal(t, cp.shouldAllow, !drop, cp.description) // If the connection should still be valid, verify it exists @@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } // Create a new outbound connection for invalid tests - drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0) + drop = manager.filterOutbound(outboundBuf.Bytes(), 0) require.False(t, drop, "Second outbound packet should not be dropped") for _, tc := range invalidCases { @@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.NoError(t, err) // Verify the invalid packet is dropped - drop = manager.dropFilter(testBuf.Bytes(), 0) + drop = manager.filterInbound(testBuf.Bytes(), 0) require.True(t, drop, tc.description) }) } diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go new file mode 100644 index 000000000..4539f7da5 --- /dev/null +++ b/client/firewall/uspfilter/nat.go @@ -0,0 +1,408 @@ +package uspfilter + +import ( + "encoding/binary" + "errors" + "fmt" + "net/netip" + + "github.com/google/gopacket/layers" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") + +func ipv4Checksum(header []byte) uint16 { + if len(header) < 20 { + return 0 + } + + var sum1, sum2 uint32 + + // Parallel processing - unroll and compute two sums simultaneously + sum1 += uint32(binary.BigEndian.Uint16(header[0:2])) + sum2 += uint32(binary.BigEndian.Uint16(header[2:4])) + sum1 += uint32(binary.BigEndian.Uint16(header[4:6])) + sum2 += uint32(binary.BigEndian.Uint16(header[6:8])) + sum1 += uint32(binary.BigEndian.Uint16(header[8:10])) + // Skip checksum field at [10:12] + sum2 += uint32(binary.BigEndian.Uint16(header[12:14])) + sum1 += uint32(binary.BigEndian.Uint16(header[14:16])) + sum2 += uint32(binary.BigEndian.Uint16(header[16:18])) + sum1 += uint32(binary.BigEndian.Uint16(header[18:20])) + + sum := sum1 + sum2 + + // Handle remaining bytes for headers > 20 bytes + for i := 20; i < len(header)-1; i += 2 { + sum += uint32(binary.BigEndian.Uint16(header[i : i+2])) + } + + if len(header)%2 == 1 { + sum += uint32(header[len(header)-1]) << 8 + } + + // Optimized carry fold - single iteration handles most cases + sum = (sum & 0xFFFF) + (sum >> 16) + if sum > 0xFFFF { + sum++ + } + + return ^uint16(sum) +} + +func icmpChecksum(data []byte) uint16 { + var sum1, sum2, sum3, sum4 uint32 + i := 0 + + // Process 16 bytes at once with 4 parallel accumulators + for i <= len(data)-16 { + sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2])) + sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4])) + sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6])) + sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8])) + sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10])) + sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12])) + sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14])) + sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16])) + i += 16 + } + + sum := sum1 + sum2 + sum3 + sum4 + + // Handle remaining bytes + for i < len(data)-1 { + sum += uint32(binary.BigEndian.Uint16(data[i : i+2])) + i += 2 + } + + if len(data)%2 == 1 { + sum += uint32(data[len(data)-1]) << 8 + } + + sum = (sum & 0xFFFF) + (sum >> 16) + if sum > 0xFFFF { + sum++ + } + + return ^uint16(sum) +} + +type biDNATMap struct { + forward map[netip.Addr]netip.Addr + reverse map[netip.Addr]netip.Addr +} + +func newBiDNATMap() *biDNATMap { + return &biDNATMap{ + forward: make(map[netip.Addr]netip.Addr), + reverse: make(map[netip.Addr]netip.Addr), + } +} + +func (b *biDNATMap) set(original, translated netip.Addr) { + b.forward[original] = translated + b.reverse[translated] = original +} + +func (b *biDNATMap) delete(original netip.Addr) { + if translated, exists := b.forward[original]; exists { + delete(b.forward, original) + delete(b.reverse, translated) + } +} + +func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { + translated, exists := b.forward[original] + return translated, exists +} + +func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { + original, exists := b.reverse[translated] + return original, exists +} + +func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { + if !originalAddr.IsValid() || !translatedAddr.IsValid() { + return fmt.Errorf("invalid IP addresses") + } + + if m.localipmanager.IsLocalIP(translatedAddr) { + return fmt.Errorf("cannot map to local IP: %s", translatedAddr) + } + + m.dnatMutex.Lock() + defer m.dnatMutex.Unlock() + + // Initialize both maps together if either is nil + if m.dnatMappings == nil || m.dnatBiMap == nil { + m.dnatMappings = make(map[netip.Addr]netip.Addr) + m.dnatBiMap = newBiDNATMap() + } + + m.dnatMappings[originalAddr] = translatedAddr + m.dnatBiMap.set(originalAddr, translatedAddr) + + if len(m.dnatMappings) == 1 { + m.dnatEnabled.Store(true) + } + + return nil +} + +// RemoveInternalDNATMapping removes a 1:1 IP address mapping +func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { + m.dnatMutex.Lock() + defer m.dnatMutex.Unlock() + + if _, exists := m.dnatMappings[originalAddr]; !exists { + return fmt.Errorf("mapping not found for: %s", originalAddr) + } + + delete(m.dnatMappings, originalAddr) + m.dnatBiMap.delete(originalAddr) + if len(m.dnatMappings) == 0 { + m.dnatEnabled.Store(false) + } + + return nil +} + +// getDNATTranslation returns the translated address if a mapping exists +func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { + if !m.dnatEnabled.Load() { + return addr, false + } + + m.dnatMutex.RLock() + translated, exists := m.dnatBiMap.getTranslated(addr) + m.dnatMutex.RUnlock() + return translated, exists +} + +// findReverseDNATMapping finds original address for return traffic +func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { + if !m.dnatEnabled.Load() { + return translatedAddr, false + } + + m.dnatMutex.RLock() + original, exists := m.dnatBiMap.getOriginal(translatedAddr) + m.dnatMutex.RUnlock() + return original, exists +} + +// translateOutboundDNAT applies DNAT translation to outbound packets +func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translatedIP, exists := m.getDNATTranslation(dstIP) + if !exists { + return false + } + + if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { + m.logger.Error("Failed to rewrite packet destination: %v", err) + return false + } + + m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP) + return true +} + +// translateInboundReverse applies reverse DNAT to inbound return traffic +func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + originalIP, exists := m.findReverseDNATMapping(srcIP) + if !exists { + return false + } + + if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { + m.logger.Error("Failed to rewrite packet source: %v", err) + return false + } + + m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP) + return true +} + +// rewritePacketDestination replaces destination IP in the packet +func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { + return ErrIPv4Only + } + + var oldDst [4]byte + copy(oldDst[:], packetData[16:20]) + newDst := newIP.As4() + + copy(packetData[16:20], newDst[:]) + + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return fmt.Errorf("invalid IP header length") + } + + binary.BigEndian.PutUint16(packetData[10:12], 0) + ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) + binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) + + if len(d.decoded) > 1 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + case layers.LayerTypeUDP: + m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + case layers.LayerTypeICMPv4: + m.updateICMPChecksum(packetData, ipHeaderLen) + } + } + + return nil +} + +// rewritePacketSource replaces the source IP address in the packet +func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { + return ErrIPv4Only + } + + var oldSrc [4]byte + copy(oldSrc[:], packetData[12:16]) + newSrc := newIP.As4() + + copy(packetData[12:16], newSrc[:]) + + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return fmt.Errorf("invalid IP header length") + } + + binary.BigEndian.PutUint16(packetData[10:12], 0) + ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) + binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) + + if len(d.decoded) > 1 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + case layers.LayerTypeUDP: + m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + case layers.LayerTypeICMPv4: + m.updateICMPChecksum(packetData, ipHeaderLen) + } + } + + return nil +} + +func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+18 { + return + } + + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) +} + +func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return + } + + checksumOffset := udpStart + 6 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + if oldChecksum == 0 { + return + } + + newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) +} + +func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { + icmpStart := ipHeaderLen + if len(packetData) < icmpStart+8 { + return + } + + icmpData := packetData[icmpStart:] + binary.BigEndian.PutUint16(icmpData[2:4], 0) + checksum := icmpChecksum(icmpData) + binary.BigEndian.PutUint16(icmpData[2:4], checksum) +} + +// incrementalUpdate performs incremental checksum update per RFC 1624 +func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { + sum := uint32(^oldChecksum) + + // Fast path for IPv4 addresses (4 bytes) - most common case + if len(oldBytes) == 4 && len(newBytes) == 4 { + sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2])) + sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) + sum += uint32(binary.BigEndian.Uint16(newBytes[0:2])) + sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) + } else { + // Fallback for other lengths + for i := 0; i < len(oldBytes)-1; i += 2 { + sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2])) + } + if len(oldBytes)%2 == 1 { + sum += uint32(^oldBytes[len(oldBytes)-1]) << 8 + } + + for i := 0; i < len(newBytes)-1; i += 2 { + sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2])) + } + if len(newBytes)%2 == 1 { + sum += uint32(newBytes[len(newBytes)-1]) << 8 + } + } + + sum = (sum & 0xFFFF) + (sum >> 16) + if sum > 0xFFFF { + sum++ + } + + return ^uint16(sum) +} + +// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { + if m.nativeFirewall == nil { + return nil, errNatNotSupported + } + return m.nativeFirewall.AddDNATRule(rule) +} + +// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { + if m.nativeFirewall == nil { + return errNatNotSupported + } + return m.nativeFirewall.DeleteDNATRule(rule) +} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go new file mode 100644 index 000000000..16dba682e --- /dev/null +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -0,0 +1,416 @@ +package uspfilter + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/device" +) + +// BenchmarkDNATTranslation measures the performance of DNAT operations +func BenchmarkDNATTranslation(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + description string + }{ + { + name: "tcp_with_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: true, + description: "TCP packet with DNAT translation enabled", + }, + { + name: "tcp_without_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + description: "TCP packet without DNAT (baseline)", + }, + { + name: "udp_with_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: true, + description: "UDP packet with DNAT translation enabled", + }, + { + name: "udp_without_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + description: "UDP packet without DNAT (baseline)", + }, + { + name: "icmp_with_dnat", + proto: layers.IPProtocolICMPv4, + setupDNAT: true, + description: "ICMP packet with DNAT translation enabled", + }, + { + name: "icmp_without_dnat", + proto: layers.IPProtocolICMPv4, + setupDNAT: false, + description: "ICMP packet without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + // Setup DNAT mapping if needed + originalIP := netip.MustParseAddr("192.168.1.100") + translatedIP := netip.MustParseAddr("10.0.0.100") + + if sc.setupDNAT { + err := manager.AddInternalDNATMapping(originalIP, translatedIP) + require.NoError(b, err) + } + + // Create test packets + srcIP := netip.MustParseAddr("172.16.0.1") + outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80) + + // Pre-establish connection for reverse DNAT test + if sc.setupDNAT { + manager.filterOutbound(outboundPacket, 0) + } + + b.ResetTimer() + + // Benchmark outbound DNAT translation + b.Run("outbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time since translation modifies it + packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80) + manager.filterOutbound(packet, 0) + } + }) + + // Benchmark inbound reverse DNAT translation + if sc.setupDNAT { + b.Run("inbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time since translation modifies it + packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345) + manager.filterInbound(packet, 0) + } + }) + } + }) + } +} + +// BenchmarkDNATConcurrency tests DNAT performance under concurrent load +func BenchmarkDNATConcurrency(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + // Setup multiple DNAT mappings + numMappings := 100 + originalIPs := make([]netip.Addr, numMappings) + translatedIPs := make([]netip.Addr, numMappings) + + for i := 0; i < numMappings; i++ { + originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1)) + translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1)) + err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i]) + require.NoError(b, err) + } + + srcIP := netip.MustParseAddr("172.16.0.1") + + // Pre-generate packets + outboundPackets := make([][]byte, numMappings) + inboundPackets := make([][]byte, numMappings) + for i := 0; i < numMappings; i++ { + outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80) + inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345) + // Establish connections + manager.filterOutbound(outboundPackets[i], 0) + } + + b.ResetTimer() + + b.Run("concurrent_outbound", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + idx := i % numMappings + packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80) + manager.filterOutbound(packet, 0) + i++ + } + }) + }) + + b.Run("concurrent_inbound", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + idx := i % numMappings + packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345) + manager.filterInbound(packet, 0) + i++ + } + }) + }) +} + +// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings +func BenchmarkDNATScaling(b *testing.B) { + mappingCounts := []int{1, 10, 100, 1000} + + for _, count := range mappingCounts { + b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + // Setup DNAT mappings + for i := 0; i < count; i++ { + originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1)) + translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1)) + err := manager.AddInternalDNATMapping(originalIP, translatedIP) + require.NoError(b, err) + } + + // Test with the last mapping added (worst case for lookup) + srcIP := netip.MustParseAddr("172.16.0.1") + lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1)) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80) + manager.filterOutbound(packet, 0) + } + }) + } +} + +// generateDNATTestPacket creates a test packet for DNAT benchmarking +func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte { + tb.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP.AsSlice(), + DstIP: dstIP.AsSlice(), + Protocol: proto, + } + + var transportLayer gopacket.SerializableLayer + switch proto { + case layers.IPProtocolTCP: + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + } + require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = tcp + case layers.IPProtocolUDP: + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = udp + case layers.IPProtocolICMPv4: + icmp := &layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), + } + transportLayer = icmp + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test")) + require.NoError(tb, err) + return buf.Bytes() +} + +// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance +func BenchmarkChecksumUpdate(b *testing.B) { + // Create test data for checksum calculations + testData := make([]byte, 64) // Typical packet size for checksum testing + for i := range testData { + testData[i] = byte(i) + } + + b.Run("ipv4_checksum", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes + } + }) + + b.Run("icmp_checksum", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = icmpChecksum(testData) + } + }) + + b.Run("incremental_update", func(b *testing.B) { + oldBytes := []byte{192, 168, 1, 100} + newBytes := []byte{10, 0, 0, 100} + oldChecksum := uint16(0x1234) + + for i := 0; i < b.N; i++ { + _ = incrementalUpdate(oldChecksum, oldBytes, newBytes) + } + }) +} + +// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations +func BenchmarkDNATMemoryAllocations(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + originalIP := netip.MustParseAddr("192.168.1.100") + translatedIP := netip.MustParseAddr("10.0.0.100") + srcIP := netip.MustParseAddr("172.16.0.1") + + err = manager.AddInternalDNATMapping(originalIP, translatedIP) + require.NoError(b, err) + + packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Create fresh packet each time to isolate allocation testing + testPacket := make([]byte, len(packet)) + copy(testPacket, packet) + + // Parse the packet fresh each time to get a clean decoder + d := &decoder{decoded: []gopacket.LayerType{}} + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + err = d.parser.DecodeLayers(testPacket, &d.decoded) + assert.NoError(b, err) + + manager.translateOutboundDNAT(testPacket, d) + } +} + +// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction +func BenchmarkDirectIPExtraction(b *testing.B) { + // Create a test packet + srcIP := netip.MustParseAddr("172.16.0.1") + dstIP := netip.MustParseAddr("192.168.1.100") + packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80) + + b.Run("direct_byte_access", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Direct extraction from packet bytes + _ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]}) + } + }) + + b.Run("decoder_extraction", func(b *testing.B) { + // Create decoder once for comparison + d := &decoder{decoded: []gopacket.LayerType{}} + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + err := d.parser.DecodeLayers(packet, &d.decoded) + assert.NoError(b, err) + + for i := 0; i < b.N; i++ { + // Extract using decoder (traditional method) + dst, _ := netip.AddrFromSlice(d.ip4.DstIP) + _ = dst + } + }) +} + +// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations +func BenchmarkChecksumOptimizations(b *testing.B) { + // Create test IPv4 header (20 bytes) + header := make([]byte, 20) + for i := range header { + header[i] = byte(i) + } + // Clear checksum field + header[10] = 0 + header[11] = 0 + + b.Run("optimized_ipv4_checksum", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = ipv4Checksum(header) + } + }) + + // Test incremental checksum updates + oldIP := []byte{192, 168, 1, 100} + newIP := []byte{10, 0, 0, 100} + oldChecksum := uint16(0x1234) + + b.Run("optimized_incremental_update", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = incrementalUpdate(oldChecksum, oldIP, newIP) + } + }) +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go new file mode 100644 index 000000000..710abd445 --- /dev/null +++ b/client/firewall/uspfilter/nat_test.go @@ -0,0 +1,145 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface/device" +) + +// TestDNATTranslationCorrectness verifies DNAT translation works correctly +func TestDNATTranslationCorrectness(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + originalIP := netip.MustParseAddr("192.168.1.100") + translatedIP := netip.MustParseAddr("10.0.0.100") + srcIP := netip.MustParseAddr("172.16.0.1") + + // Add DNAT mapping + err = manager.AddInternalDNATMapping(originalIP, translatedIP) + require.NoError(t, err) + + testCases := []struct { + name string + protocol layers.IPProtocol + srcPort uint16 + dstPort uint16 + }{ + {"TCP", layers.IPProtocolTCP, 12345, 80}, + {"UDP", layers.IPProtocolUDP, 12345, 53}, + {"ICMP", layers.IPProtocolICMPv4, 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test outbound DNAT translation + outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort) + originalOutbound := make([]byte, len(outboundPacket)) + copy(originalOutbound, outboundPacket) + + // Process outbound packet (should translate destination) + translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket)) + require.True(t, translated, "Outbound packet should be translated") + + // Verify destination IP was changed + dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]}) + require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated") + + // Test inbound reverse DNAT translation + inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort) + originalInbound := make([]byte, len(inboundPacket)) + copy(originalInbound, inboundPacket) + + // Process inbound packet (should reverse translate source) + reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket)) + require.True(t, reversed, "Inbound packet should be reverse translated") + + // Verify source IP was changed back to original + srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]}) + require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated") + + // Test that checksums are recalculated correctly + if tc.protocol != layers.IPProtocolICMPv4 { + // For TCP/UDP, verify the transport checksum was updated + require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified") + require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified") + } + }) + } +} + +// parsePacket helper to create a decoder for testing +func parsePacket(t testing.TB, packetData []byte) *decoder { + t.Helper() + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + + err := d.parser.DecodeLayers(packetData, &d.decoded) + require.NoError(t, err) + return d +} + +// TestDNATMappingManagement tests adding/removing DNAT mappings +func TestDNATMappingManagement(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + originalIP := netip.MustParseAddr("192.168.1.100") + translatedIP := netip.MustParseAddr("10.0.0.100") + + // Test adding mapping + err = manager.AddInternalDNATMapping(originalIP, translatedIP) + require.NoError(t, err) + + // Verify mapping exists + result, exists := manager.getDNATTranslation(originalIP) + require.True(t, exists) + require.Equal(t, translatedIP, result) + + // Test reverse lookup + reverseResult, exists := manager.findReverseDNATMapping(translatedIP) + require.True(t, exists) + require.Equal(t, originalIP, reverseResult) + + // Test removing mapping + err = manager.RemoveInternalDNATMapping(originalIP) + require.NoError(t, err) + + // Verify mapping no longer exists + _, exists = manager.getDNATTranslation(originalIP) + require.False(t, exists) + + _, exists = manager.findReverseDNATMapping(translatedIP) + require.False(t, exists) + + // Test error cases + err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP) + require.Error(t, err, "Should reject invalid original IP") + + err = manager.AddInternalDNATMapping(originalIP, netip.Addr{}) + require.Error(t, err, "Should reject invalid translated IP") + + err = manager.RemoveInternalDNATMapping(originalIP) + require.Error(t, err, "Should error when removing non-existent mapping") +} diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index 53350797c..ef04f2700 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { // will create or update the connection state - dropped := m.processOutgoingHooks(packetData, 0) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) } else { diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 5a1a0e96a..015f71ff4 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -9,11 +9,11 @@ import ( // PacketFilter interface for firewall abilities type PacketFilter interface { - // DropOutgoing filter outgoing packets from host to external destinations - DropOutgoing(packetData []byte, size int) bool + // FilterOutbound filter outgoing packets from host to external destinations + FilterOutbound(packetData []byte, size int) bool - // DropIncoming filter incoming packets from external sources to host - DropIncoming(packetData []byte, size int) bool + // FilterInbound filter incoming packets from external sources to host + FilterInbound(packetData []byte, size int) bool // AddUDPPacketHook calls hook when UDP packet from given direction matched // @@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er } for i := 0; i < n; i++ { - if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) { + if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { bufs = append(bufs[:i], bufs[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...) n-- @@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { filteredBufs := make([][]byte, 0, len(bufs)) dropped := 0 for _, buf := range bufs { - if !filter.DropIncoming(buf[offset:], len(buf)) { + if !filter.FilterInbound(buf[offset:], len(buf)) { filteredBufs = append(filteredBufs, buf) dropped++ } diff --git a/client/iface/device/device_filter_test.go b/client/iface/device/device_filter_test.go index c90269e82..eef783542 100644 --- a/client/iface/device/device_filter_test.go +++ b/client/iface/device/device_filter_test.go @@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) { tun.EXPECT().Write(mockBufs, 0).Return(0, nil) filter := mocks.NewMockPacketFilter(ctrl) - filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true) + filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true) wrapped := newDeviceFilter(tun) wrapped.filter = filter @@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) { return 1, nil }) filter := mocks.NewMockPacketFilter(ctrl) - filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true) + filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true) wrapped := newDeviceFilter(tun) wrapped.filter = filter diff --git a/client/iface/mocks/filter.go b/client/iface/mocks/filter.go index 8cd2a1231..566068aa5 100644 --- a/client/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) } -// DropIncoming mocks base method. -func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool { +// FilterInbound mocks base method. +func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1) + ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } -// DropIncoming indicates an expected call of DropIncoming. -func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call { +// FilterInbound indicates an expected call of FilterInbound. +func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1) } -// DropOutgoing mocks base method. -func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool { +// FilterOutbound mocks base method. +func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1) + ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } -// DropOutgoing indicates an expected call of DropOutgoing. -func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call { +// FilterOutbound indicates an expected call of FilterOutbound. +func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1) } // RemovePacketHook mocks base method. diff --git a/client/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go index 17e123abb..291ab9ab5 100644 --- a/client/iface/mocks/iface/mocks/filter.go +++ b/client/iface/mocks/iface/mocks/filter.go @@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) } -// DropIncoming mocks base method. -func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { +// FilterInbound mocks base method. +func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropIncoming", arg0) + ret := m.ctrl.Call(m, "FilterInbound", arg0) ret0, _ := ret[0].(bool) return ret0 } -// DropIncoming indicates an expected call of DropIncoming. -func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { +// FilterInbound indicates an expected call of FilterInbound. +func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0) } -// DropOutgoing mocks base method. -func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { +// FilterOutbound mocks base method. +func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropOutgoing", arg0) + ret := m.ctrl.Call(m, "FilterOutbound", arg0) ret0, _ := ret[0].(bool) return ret0 } -// DropOutgoing indicates an expected call of DropOutgoing. -func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { +// FilterOutbound indicates an expected call of FilterOutbound. +func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0) } // SetNetwork mocks base method. diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 1cf59fb5b..21a9e2f2d 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer ctrl.Finish() packetfilter := pfmock.NewMockPacketFilter(ctrl) - packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() + packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any()) diff --git a/client/internal/engine.go b/client/internal/engine.go index 4ea6fbd94..74d84569a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -383,7 +383,13 @@ func (e *Engine) Start() error { } e.stateManager.Start() - initialRoutes, dnsServer, err := e.newDnsServer() + initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings() + if err != nil { + e.close() + return fmt.Errorf("read initial settings: %w", err) + } + + dnsServer, err := e.newDnsServer(dnsConfig) if err != nil { e.close() return fmt.Errorf("create dns server: %w", err) @@ -400,6 +406,7 @@ func (e *Engine) Start() error { InitialRoutes: initialRoutes, StateManager: e.stateManager, DNSServer: dnsServer, + DNSFeatureFlag: dnsFeatureFlag, PeerStore: e.peerStore, DisableClientRoutes: e.config.DisableClientRoutes, DisableServerRoutes: e.config.DisableServerRoutes, @@ -488,9 +495,9 @@ func (e *Engine) createFirewall() error { } func (e *Engine) initFirewall() error { - if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { + if err := e.routeManager.SetFirewall(e.firewall); err != nil { e.close() - return fmt.Errorf("enable server router: %w", err) + return fmt.Errorf("set firewall: %w", err) } if e.config.BlockLANAccess { @@ -1009,8 +1016,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } - dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) - // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) @@ -1021,6 +1026,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes)) } + dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update routes: %v", err) } @@ -1489,7 +1495,12 @@ func (e *Engine) close() { } } -func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { +func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { + if runtime.GOOS != "android" { + // nolint:nilnil + return nil, nil, false, nil + } + info := system.GetInfo(e.ctx) info.SetFlags( e.config.RosenpassEnabled, @@ -1506,11 +1517,12 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { netMap, err := e.mgmClient.GetNetworkMap(info) if err != nil { - return nil, nil, err + return nil, nil, false, err } routes := toRoutes(netMap.GetRoutes()) dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network) - return routes, &dnsCfg, nil + dnsFeatureFlag := toDNSFeatureFlag(netMap) + return routes, &dnsCfg, dnsFeatureFlag, nil } func (e *Engine) newWgIface() (*iface.WGIface, error) { @@ -1558,18 +1570,14 @@ func (e *Engine) wgInterfaceCreate() (err error) { return err } -func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { +func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { // due to tests where we are using a mocked version of the DNS server if e.dnsServer != nil { - return nil, e.dnsServer, nil + return e.dnsServer, nil } switch runtime.GOOS { case "android": - routes, dnsConfig, err := e.readInitialSettings() - if err != nil { - return nil, nil, err - } dnsServer := dns.NewDefaultServerPermanentUpstream( e.ctx, e.wgInterface, @@ -1580,19 +1588,19 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { e.config.DisableDNS, ) go e.mobileDep.DnsReadyListener.OnReady() - return routes, dnsServer, nil + return dnsServer, nil case "ios": dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) - return nil, dnsServer, nil + return dnsServer, nil default: dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) if err != nil { - return nil, nil, err + return nil, err } - return nil, dnsServer, nil + return dnsServer, nil } } diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index 46bff96db..0b8e161d2 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -10,11 +10,10 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/iface" - "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/route" @@ -553,41 +552,16 @@ func (w *Watcher) Stop() { w.currentChosenStatus = nil } -func HandlerFromRoute( - rt *route.Route, - routeRefCounter *refcounter.RouteRefCounter, - allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, - dnsRouterInteval time.Duration, - statusRecorder *peer.Status, - wgInterface iface.WGIface, - dnsServer nbdns.Server, - peerStore *peerstore.Store, - useNewDNSRoute bool, -) RouteHandler { - switch handlerType(rt, useNewDNSRoute) { +func HandlerFromRoute(params common.HandlerParams) RouteHandler { + switch handlerType(params.Route, params.UseNewDNSRoute) { case handlerTypeDnsInterceptor: - return dnsinterceptor.New( - rt, - routeRefCounter, - allowedIPsRefCounter, - statusRecorder, - dnsServer, - wgInterface, - peerStore, - ) + return dnsinterceptor.New(params) case handlerTypeDynamic: - dns := nbdns.NewServiceViaMemory(wgInterface) - return dynamic.NewRoute( - rt, - routeRefCounter, - allowedIPsRefCounter, - dnsRouterInteval, - statusRecorder, - wgInterface, - fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), - ) + dns := nbdns.NewServiceViaMemory(params.WgInterface) + dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()) + return dynamic.NewRoute(params, dnsAddr) default: - return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) + return static.NewRoute(params) } } diff --git a/client/internal/routemanager/client/client_test.go b/client/internal/routemanager/client/client_test.go index e7aff28b6..ec8e0e944 100644 --- a/client/internal/routemanager/client/client_test.go +++ b/client/internal/routemanager/client/client_test.go @@ -7,12 +7,12 @@ import ( "time" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/route" ) func TestGetBestrouteFromStatuses(t *testing.T) { - testCases := []struct { name string statuses map[route.ID]routerPeerStatus @@ -811,9 +811,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) { currentRoute = tc.existingRoutes[tc.currentRoute] } + params := common.HandlerParams{ + Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, + } // create new clientNetwork client := &Watcher{ - handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil), + handler: static.NewRoute(params), routes: tc.existingRoutes, currentChosen: currentRoute, } diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go new file mode 100644 index 000000000..def18411f --- /dev/null +++ b/client/internal/routemanager/common/params.go @@ -0,0 +1,28 @@ +package common + +import ( + "time" + + "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" + "github.com/netbirdio/netbird/client/internal/routemanager/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/route" +) + +type HandlerParams struct { + Route *route.Route + RouteRefCounter *refcounter.RouteRefCounter + AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter + DnsRouterInterval time.Duration + StatusRecorder *peer.Status + WgInterface iface.WGIface + DnsServer dns.Server + PeerStore *peerstore.Store + UseNewDNSRoute bool + Firewall manager.Manager + FakeIPManager *fakeip.Manager +} diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 66557e888..c7c3aeb0b 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "runtime" "strings" "sync" @@ -12,11 +13,14 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/common" + "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" @@ -24,6 +28,11 @@ import ( type domainMap map[domain.Domain][]netip.Prefix +type internalDNATer interface { + RemoveInternalDNATMapping(netip.Addr) error + AddInternalDNATMapping(netip.Addr, netip.Addr) error +} + type wgInterface interface { Name() string Address() wgaddr.Address @@ -40,26 +49,22 @@ type DnsInterceptor struct { interceptedDomains domainMap wgInterface wgInterface peerStore *peerstore.Store + firewall firewall.Manager + fakeIPManager *fakeip.Manager } -func New( - rt *route.Route, - routeRefCounter *refcounter.RouteRefCounter, - allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, - statusRecorder *peer.Status, - dnsServer nbdns.Server, - wgInterface wgInterface, - peerStore *peerstore.Store, -) *DnsInterceptor { +func New(params common.HandlerParams) *DnsInterceptor { return &DnsInterceptor{ - route: rt, - routeRefCounter: routeRefCounter, - allowedIPsRefcounter: allowedIPsRefCounter, - statusRecorder: statusRecorder, - dnsServer: dnsServer, - wgInterface: wgInterface, + route: params.Route, + routeRefCounter: params.RouteRefCounter, + allowedIPsRefcounter: params.AllowedIPsRefCounter, + statusRecorder: params.StatusRecorder, + dnsServer: params.DnsServer, + wgInterface: params.WgInterface, + peerStore: params.PeerStore, + firewall: params.Firewall, + fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), - peerStore: peerStore, } } @@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error { var merr *multierror.Error for domain, prefixes := range d.interceptedDomains { for _, prefix := range prefixes { - if _, err := d.routeRefCounter.Decrement(prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) + // Routes should use fake IPs + routePrefix := d.transformRealToFakePrefix(prefix) + if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err)) } + + // AllowedIPs should use real IPs if d.currentPeerKey != "" { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) @@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error { } } log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) - } + + d.cleanupDNATMappings() + for _, domain := range d.route.Domains { d.statusRecorder.DeleteResolvedDomainsStates(domain) } @@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error { return nberrors.FormatErrorOrNil(merr) } +// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled) +func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix { + if _, hasDNAT := d.internalDnatFw(); !hasDNAT { + return realPrefix + } + + if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok { + return netip.PrefixFrom(fakeIP, realPrefix.Bits()) + } + + return realPrefix +} + +// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs) +func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error { + // AllowedIPs always use real IPs + ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey) + if err != nil { + return fmt.Errorf("add allowed IP %s: %v", realPrefix, err) + } + + if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", + realPrefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } + + return nil +} + +// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix +func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error { + // Routes use fake IPs (so traffic to fake IPs gets routed to interface) + routePrefix := d.transformRealToFakePrefix(realPrefix) + if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil { + return fmt.Errorf("add route for IP %s: %v", routePrefix, err) + } + + // Add to AllowedIPs if we have a current peer (uses real IPs) + if d.currentPeerKey == "" { + return nil + } + + return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain) +} + +// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs) +func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error { + if d.currentPeerKey == "" { + return nil + } + + // AllowedIPs use real IPs + if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil { + return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err) + } + + return nil +} + func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { d.mu.Lock() defer d.mu.Unlock() @@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { var merr *multierror.Error for domain, prefixes := range d.interceptedDomains { for _, prefix := range prefixes { - if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) - } else if ref.Count > 1 && ref.Out != peerKey { - log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", - prefix.Addr(), - domain.SafeString(), - ref.Out, - ) + // AllowedIPs use real IPs + if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil { + merr = multierror.Append(merr, err) } } } @@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error { var merr *multierror.Error for _, prefixes := range d.interceptedDomains { for _, prefix := range prefixes { + // AllowedIPs use real IPs if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) } @@ -287,6 +356,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { log.Errorf("failed to update domain prefixes: %v", err) } + + d.replaceIPsInDNSResponse(r, newPrefixes) } } @@ -297,6 +368,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { return nil } +// logPrefixChanges handles the logging for prefix changes +func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) { + if len(toAdd) > 0 { + log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toAdd) + } + if len(toRemove) > 0 && !d.route.KeepRoute { + log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toRemove) + } +} + func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { d.mu.Lock() defer d.mu.Unlock() @@ -305,70 +392,163 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) var merr *multierror.Error + var dnatMappings map[netip.Addr]netip.Addr + + // Handle DNAT mappings for new prefixes + if _, hasDNAT := d.internalDnatFw(); hasDNAT { + dnatMappings = make(map[netip.Addr]netip.Addr) + for _, prefix := range toAdd { + realIP := prefix.Addr() + if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil { + dnatMappings[fakeIP] = realIP + log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP) + } else { + log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err) + } + } + } // Add new prefixes for _, prefix := range toAdd { - if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) - continue - } - - if d.currentPeerKey == "" { - continue - } - if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) - } else if ref.Count > 1 && ref.Out != d.currentPeerKey { - log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled", - prefix.Addr(), - resolvedDomain.SafeString(), - ref.Out, - ) + if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil { + merr = multierror.Append(merr, err) } } + d.addDNATMappings(dnatMappings) + if !d.route.KeepRoute { // Remove old prefixes for _, prefix := range toRemove { - if _, err := d.routeRefCounter.Decrement(prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) + // Routes use fake IPs + routePrefix := d.transformRealToFakePrefix(prefix) + if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err)) } - if d.currentPeerKey != "" { - if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) - } + // AllowedIPs use real IPs + if err := d.removeAllowedIP(prefix); err != nil { + merr = multierror.Append(merr, err) } } + + d.removeDNATMappings(toRemove) } - // Update domain prefixes using resolved domain as key + // Update domain prefixes using resolved domain as key - store real IPs if len(toAdd) > 0 || len(toRemove) > 0 { if d.route.KeepRoute { - // replace stored prefixes with old + added // nolint:gocritic newPrefixes = append(oldPrefixes, toAdd...) } d.interceptedDomains[resolvedDomain] = newPrefixes originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) + + // Store real IPs for status (user-facing), not fake IPs d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) - if len(toAdd) > 0 { - log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", - resolvedDomain.SafeString(), - originalDomain.SafeString(), - toAdd) - } - if len(toRemove) > 0 && !d.route.KeepRoute { - log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", - resolvedDomain.SafeString(), - originalDomain.SafeString(), - toRemove) - } + d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove) } return nberrors.FormatErrorOrNil(merr) } +// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes +func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) { + if len(realPrefixes) == 0 { + return + } + + dnatFirewall, ok := d.internalDnatFw() + if !ok { + return + } + + for _, prefix := range realPrefixes { + realIP := prefix.Addr() + if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { + if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil { + log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err) + } else { + log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP) + } + } + } +} + +// internalDnatFw checks if the firewall supports internal DNAT +func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) { + if d.firewall == nil || runtime.GOOS != "android" { + return nil, false + } + fw, ok := d.firewall.(internalDNATer) + return fw, ok +} + +// addDNATMappings adds DNAT mappings to the firewall +func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) { + if len(mappings) == 0 { + return + } + + dnatFirewall, ok := d.internalDnatFw() + if !ok { + return + } + + for fakeIP, realIP := range mappings { + if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil { + log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err) + } else { + log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP) + } + } +} + +// cleanupDNATMappings removes all DNAT mappings for this interceptor +func (d *DnsInterceptor) cleanupDNATMappings() { + if _, ok := d.internalDnatFw(); !ok { + return + } + + for _, prefixes := range d.interceptedDomains { + d.removeDNATMappings(prefixes) + } +} + +// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response +func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) { + if _, ok := d.internalDnatFw(); !ok { + return + } + + // Replace A and AAAA records with fake IPs + for _, answer := range reply.Answer { + switch rr := answer.(type) { + case *dns.A: + realIP, ok := netip.AddrFromSlice(rr.A) + if !ok { + continue + } + + if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { + rr.A = fakeIP.AsSlice() + log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) + } + + case *dns.AAAA: + realIP, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + continue + } + + if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { + rr.AAAA = fakeIP.AsSlice() + log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) + } + } + } +} + func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { prefixSet := make(map[netip.Prefix]bool) for _, prefix := range oldPrefixes { diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 47511d4af..5d561f0cf 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -14,6 +14,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" @@ -52,24 +53,16 @@ type Route struct { resolverAddr string } -func NewRoute( - rt *route.Route, - routeRefCounter *refcounter.RouteRefCounter, - allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, - interval time.Duration, - statusRecorder *peer.Status, - wgInterface iface.WGIface, - resolverAddr string, -) *Route { +func NewRoute(params common.HandlerParams, resolverAddr string) *Route { return &Route{ - route: rt, - routeRefCounter: routeRefCounter, - allowedIPsRefcounter: allowedIPsRefCounter, - interval: interval, - dynamicDomains: domainMap{}, - statusRecorder: statusRecorder, - wgInterface: wgInterface, + route: params.Route, + routeRefCounter: params.RouteRefCounter, + allowedIPsRefcounter: params.AllowedIPsRefCounter, + interval: params.DnsRouterInterval, + statusRecorder: params.StatusRecorder, + wgInterface: params.WgInterface, resolverAddr: resolverAddr, + dynamicDomains: domainMap{}, } } diff --git a/client/internal/routemanager/fakeip/fakeip.go b/client/internal/routemanager/fakeip/fakeip.go new file mode 100644 index 000000000..1592045d2 --- /dev/null +++ b/client/internal/routemanager/fakeip/fakeip.go @@ -0,0 +1,93 @@ +package fakeip + +import ( + "fmt" + "net/netip" + "sync" +) + +// Manager manages allocation of fake IPs from the 240.0.0.0/8 block +type Manager struct { + mu sync.Mutex + nextIP netip.Addr // Next IP to allocate + allocated map[netip.Addr]netip.Addr // real IP -> fake IP + fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP + baseIP netip.Addr // First usable IP: 240.0.0.1 + maxIP netip.Addr // Last usable IP: 240.255.255.254 +} + +// NewManager creates a new fake IP manager using 240.0.0.0/8 block +func NewManager() *Manager { + baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1}) + maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254}) + + return &Manager{ + nextIP: baseIP, + allocated: make(map[netip.Addr]netip.Addr), + fakeToReal: make(map[netip.Addr]netip.Addr), + baseIP: baseIP, + maxIP: maxIP, + } +} + +// AllocateFakeIP allocates a fake IP for the given real IP +// Returns the fake IP, or existing fake IP if already allocated +func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { + if !realIP.Is4() { + return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported") + } + + m.mu.Lock() + defer m.mu.Unlock() + + if fakeIP, exists := m.allocated[realIP]; exists { + return fakeIP, nil + } + + startIP := m.nextIP + for { + currentIP := m.nextIP + + // Advance to next IP, wrapping at boundary + if m.nextIP.Compare(m.maxIP) >= 0 { + m.nextIP = m.baseIP + } else { + m.nextIP = m.nextIP.Next() + } + + // Check if current IP is available + if _, inUse := m.fakeToReal[currentIP]; !inUse { + m.allocated[realIP] = currentIP + m.fakeToReal[currentIP] = realIP + return currentIP, nil + } + + // Prevent infinite loop if all IPs exhausted + if m.nextIP.Compare(startIP) == 0 { + return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block") + } + } +} + +// GetFakeIP returns the fake IP for a real IP if it exists +func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + fakeIP, exists := m.allocated[realIP] + return fakeIP, exists +} + +// GetRealIP returns the real IP for a fake IP if it exists, otherwise false +func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + realIP, exists := m.fakeToReal[fakeIP] + return realIP, exists +} + +// GetFakeIPBlock returns the fake IP block used by this manager +func (m *Manager) GetFakeIPBlock() netip.Prefix { + return netip.MustParsePrefix("240.0.0.0/8") +} diff --git a/client/internal/routemanager/fakeip/fakeip_test.go b/client/internal/routemanager/fakeip/fakeip_test.go new file mode 100644 index 000000000..ad3e4bd4e --- /dev/null +++ b/client/internal/routemanager/fakeip/fakeip_test.go @@ -0,0 +1,240 @@ +package fakeip + +import ( + "net/netip" + "sync" + "testing" +) + +func TestNewManager(t *testing.T) { + manager := NewManager() + + if manager.baseIP.String() != "240.0.0.1" { + t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String()) + } + + if manager.maxIP.String() != "240.255.255.254" { + t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String()) + } + + if manager.nextIP.Compare(manager.baseIP) != 0 { + t.Errorf("Expected nextIP to start at baseIP") + } +} + +func TestAllocateFakeIP(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("8.8.8.8") + + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IP: %v", err) + } + + if !fakeIP.Is4() { + t.Error("Fake IP should be IPv4") + } + + // Check it's in the correct range + if fakeIP.As4()[0] != 240 { + t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String()) + } + + // Should return same fake IP for same real IP + fakeIP2, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to get existing fake IP: %v", err) + } + + if fakeIP.Compare(fakeIP2) != 0 { + t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String()) + } +} + +func TestAllocateFakeIPIPv6Rejection(t *testing.T) { + manager := NewManager() + realIPv6 := netip.MustParseAddr("2001:db8::1") + + _, err := manager.AllocateFakeIP(realIPv6) + if err == nil { + t.Error("Expected error for IPv6 address") + } +} + +func TestGetFakeIP(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("1.1.1.1") + + // Should not exist initially + _, exists := manager.GetFakeIP(realIP) + if exists { + t.Error("Fake IP should not exist before allocation") + } + + // Allocate and check + expectedFakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate: %v", err) + } + + fakeIP, exists := manager.GetFakeIP(realIP) + if !exists { + t.Error("Fake IP should exist after allocation") + } + + if fakeIP.Compare(expectedFakeIP) != 0 { + t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String()) + } +} + +func TestMultipleAllocations(t *testing.T) { + manager := NewManager() + + allocations := make(map[netip.Addr]netip.Addr) + + // Allocate multiple IPs + for i := 1; i <= 100; i++ { + realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err) + } + + // Check for duplicates + for _, existingFake := range allocations { + if fakeIP.Compare(existingFake) == 0 { + t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String()) + } + } + + allocations[realIP] = fakeIP + } + + // Verify all allocations can be retrieved + for realIP, expectedFake := range allocations { + actualFake, exists := manager.GetFakeIP(realIP) + if !exists { + t.Errorf("Missing allocation for %s", realIP.String()) + } + if actualFake.Compare(expectedFake) != 0 { + t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String()) + } + } +} + +func TestGetFakeIPBlock(t *testing.T) { + manager := NewManager() + block := manager.GetFakeIPBlock() + + expected := "240.0.0.0/8" + if block.String() != expected { + t.Errorf("Expected %s, got %s", expected, block.String()) + } +} + +func TestConcurrentAccess(t *testing.T) { + manager := NewManager() + + const numGoroutines = 50 + const allocationsPerGoroutine = 10 + + var wg sync.WaitGroup + results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine) + + // Concurrent allocations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + for j := 0; j < allocationsPerGoroutine; j++ { + realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)}) + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err) + return + } + results <- fakeIP + } + }(i) + } + + wg.Wait() + close(results) + + // Check for duplicates + seen := make(map[netip.Addr]bool) + count := 0 + for fakeIP := range results { + if seen[fakeIP] { + t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String()) + } + seen[fakeIP] = true + count++ + } + + if count != numGoroutines*allocationsPerGoroutine { + t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count) + } +} + +func TestIPExhaustion(t *testing.T) { + // Create a manager with limited range for testing + manager := &Manager{ + nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), + allocated: make(map[netip.Addr]netip.Addr), + fakeToReal: make(map[netip.Addr]netip.Addr), + baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), + maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available + } + + // Allocate all available IPs + realIPs := []netip.Addr{ + netip.MustParseAddr("1.0.0.1"), + netip.MustParseAddr("1.0.0.2"), + netip.MustParseAddr("1.0.0.3"), + } + + for _, realIP := range realIPs { + _, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IP: %v", err) + } + } + + // Try to allocate one more - should fail + _, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4")) + if err == nil { + t.Error("Expected exhaustion error") + } +} + +func TestWrapAround(t *testing.T) { + // Create manager starting near the end of range + manager := &Manager{ + nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + allocated: make(map[netip.Addr]netip.Addr), + fakeToReal: make(map[netip.Addr]netip.Addr), + baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), + maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + } + + // Allocate the last IP + fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1")) + if err != nil { + t.Fatalf("Failed to allocate first IP: %v", err) + } + + if fakeIP1.String() != "240.0.0.254" { + t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String()) + } + + // Next allocation should wrap around to the beginning + fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2")) + if err != nil { + t.Fatalf("Failed to allocate second IP: %v", err) + } + + if fakeIP2.String() != "240.0.0.1" { + t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String()) + } +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 919bf25e3..e0974ab2a 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -8,9 +8,11 @@ import ( "net/netip" "net/url" "runtime" + "slices" "sync" "time" + "github.com/google/uuid" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -24,6 +26,8 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" + "github.com/netbirdio/netbird/client/internal/routemanager/common" + "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -49,7 +53,7 @@ type Manager interface { GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string - EnableServerRouter(firewall firewall.Manager) error + SetFirewall(firewall.Manager) error Stop(stateManager *statemanager.Manager) } @@ -63,6 +67,7 @@ type ManagerConfig struct { InitialRoutes []*route.Route StateManager *statemanager.Manager DNSServer dns.Server + DNSFeatureFlag bool PeerStore *peerstore.Store DisableClientRoutes bool DisableServerRoutes bool @@ -89,11 +94,13 @@ type DefaultManager struct { // clientRoutes is the most recent list of clientRoutes received from the Management Service clientRoutes route.HAMap dnsServer dns.Server + firewall firewall.Manager peerStore *peerstore.Store useNewDNSRoute bool disableClientRoutes bool disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler + fakeIPManager *fakeip.Manager } func NewManager(config ManagerConfig) *DefaultManager { @@ -129,11 +136,31 @@ func NewManager(config ManagerConfig) *DefaultManager { } if runtime.GOOS == "android" { - cr := dm.initialClientRoutes(config.InitialRoutes) - dm.notifier.SetInitialClientRoutes(cr) + dm.setupAndroidRoutes(config) } return dm } +func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { + cr := m.initialClientRoutes(config.InitialRoutes) + + routesForComparison := slices.Clone(cr) + + if config.DNSFeatureFlag { + m.fakeIPManager = fakeip.NewManager() + + id := uuid.NewString() + fakeIPRoute := &route.Route{ + ID: route.ID(id), + Network: m.fakeIPManager.GetFakeIPBlock(), + NetID: route.NetID(id), + Peer: m.pubKey, + NetworkType: route.IPv4Network, + } + cr = append(cr, fakeIPRoute) + } + + m.notifier.SetInitialClientRoutes(cr, routesForComparison) +} func (m *DefaultManager) setupRefCounters(useNoop bool) { m.routeRefCounter = refcounter.New( @@ -222,16 +249,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector { return routeselector.NewRouteSelector() } -func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { - if m.disableServerRoutes { +// SetFirewall sets the firewall manager for the DefaultManager +// Not thread-safe, should be called before starting the manager +func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { + m.firewall = firewall + + if m.disableServerRoutes || firewall == nil { log.Info("server routes are disabled") return nil } - if firewall == nil { - return errors.New("firewall manager is not set") - } - var err error m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) if err != nil { @@ -299,17 +326,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { } for id, route := range toAdd { - handler := client.HandlerFromRoute( - route, - m.routeRefCounter, - m.allowedIPsRefCounter, - m.dnsRouteInterval, - m.statusRecorder, - m.wgInterface, - m.dnsServer, - m.peerStore, - m.useNewDNSRoute, - ) + params := common.HandlerParams{ + Route: route, + RouteRefCounter: m.routeRefCounter, + AllowedIPsRefCounter: m.allowedIPsRefCounter, + DnsRouterInterval: m.dnsRouteInterval, + StatusRecorder: m.statusRecorder, + WgInterface: m.wgInterface, + DnsServer: m.dnsServer, + PeerStore: m.peerStore, + UseNewDNSRoute: m.useNewDNSRoute, + Firewall: m.firewall, + FakeIPManager: m.fakeIPManager, + } + handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err)) continue @@ -517,6 +547,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro for _, routes := range crMap { rs = append(rs, routes...) } + return rs } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 742294cdf..4e182f82c 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList } -func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { +func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } diff --git a/client/internal/routemanager/notifier/notifier.go b/client/internal/routemanager/notifier/notifier.go deleted file mode 100644 index 3cc7c3308..000000000 --- a/client/internal/routemanager/notifier/notifier.go +++ /dev/null @@ -1,124 +0,0 @@ -package notifier - -import ( - "net/netip" - "runtime" - "sort" - "strings" - "sync" - - "github.com/netbirdio/netbird/client/internal/listener" - "github.com/netbirdio/netbird/route" -) - -type Notifier struct { - initialRouteRanges []string - routeRanges []string - - listener listener.NetworkChangeListener - listenerMux sync.Mutex -} - -func NewNotifier() *Notifier { - return &Notifier{} -} - -func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { - n.listenerMux.Lock() - defer n.listenerMux.Unlock() - n.listener = listener -} - -func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { - nets := make([]string, 0) - for _, r := range clientRoutes { - if r.IsDynamic() { - continue - } - nets = append(nets, r.Network.String()) - } - sort.Strings(nets) - n.initialRouteRanges = nets -} - -func (n *Notifier) OnNewRoutes(idMap route.HAMap) { - if runtime.GOOS != "android" { - return - } - - var newNets []string - for _, routes := range idMap { - for _, r := range routes { - if r.IsDynamic() { - continue - } - newNets = append(newNets, r.Network.String()) - } - } - - sort.Strings(newNets) - if !n.hasDiff(n.initialRouteRanges, newNets) { - return - } - - n.routeRanges = newNets - n.notify() -} - -// OnNewPrefixes is called from iOS only -func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { - newNets := make([]string, 0) - for _, prefix := range prefixes { - newNets = append(newNets, prefix.String()) - } - - sort.Strings(newNets) - if !n.hasDiff(n.routeRanges, newNets) { - return - } - - n.routeRanges = newNets - n.notify() -} - -func (n *Notifier) notify() { - n.listenerMux.Lock() - defer n.listenerMux.Unlock() - if n.listener == nil { - return - } - - go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ",")) - }(n.listener) -} - -func (n *Notifier) hasDiff(a []string, b []string) bool { - if len(a) != len(b) { - return true - } - for i, v := range a { - if v != b[i] { - return true - } - } - return false -} - -func (n *Notifier) GetInitialRouteRanges() []string { - return addIPv6RangeIfNeeded(n.initialRouteRanges) -} - -// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route. -func addIPv6RangeIfNeeded(inputRanges []string) []string { - ranges := inputRanges - for _, r := range inputRanges { - // we are intentionally adding the ipv6 default range in case of ipv4 default range - // to ensure that all traffic is managed by the tunnel interface on android - if r == "0.0.0.0/0" { - ranges = append(ranges, "::/0") - break - } - } - return ranges -} diff --git a/client/internal/routemanager/notifier/notifier_android.go b/client/internal/routemanager/notifier/notifier_android.go new file mode 100644 index 000000000..dec0af87c --- /dev/null +++ b/client/internal/routemanager/notifier/notifier_android.go @@ -0,0 +1,127 @@ +//go:build android + +package notifier + +import ( + "net/netip" + "slices" + "sort" + "strings" + "sync" + + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/route" +) + +type Notifier struct { + initialRoutes []*route.Route + currentRoutes []*route.Route + + listener listener.NetworkChangeListener + listenerMux sync.Mutex +} + +func NewNotifier() *Notifier { + return &Notifier{} +} + +func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + n.listener = listener +} + +func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) { + // initialRoutes contains fake IP block for interface configuration + filteredInitial := make([]*route.Route, 0) + for _, r := range initialRoutes { + if r.IsDynamic() { + continue + } + filteredInitial = append(filteredInitial, r) + } + n.initialRoutes = filteredInitial + + // routesForComparison excludes fake IP block for comparison with new routes + filteredComparison := make([]*route.Route, 0) + for _, r := range routesForComparison { + if r.IsDynamic() { + continue + } + filteredComparison = append(filteredComparison, r) + } + n.currentRoutes = filteredComparison +} + +func (n *Notifier) OnNewRoutes(idMap route.HAMap) { + var newRoutes []*route.Route + for _, routes := range idMap { + for _, r := range routes { + if r.IsDynamic() { + continue + } + newRoutes = append(newRoutes, r) + } + } + + if !n.hasRouteDiff(n.currentRoutes, newRoutes) { + return + } + + n.currentRoutes = newRoutes + n.notify() +} + +func (n *Notifier) OnNewPrefixes([]netip.Prefix) { + // Not used on Android +} + +func (n *Notifier) notify() { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + if n.listener == nil { + return + } + + routeStrings := n.routesToStrings(n.currentRoutes) + sort.Strings(routeStrings) + go func(l listener.NetworkChangeListener) { + l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ",")) + }(n.listener) +} + +func (n *Notifier) routesToStrings(routes []*route.Route) []string { + nets := make([]string, 0, len(routes)) + for _, r := range routes { + nets = append(nets, r.NetString()) + } + return nets +} + +func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool { + slices.SortFunc(a, func(x, y *route.Route) int { + return strings.Compare(x.NetString(), y.NetString()) + }) + slices.SortFunc(b, func(x, y *route.Route) int { + return strings.Compare(x.NetString(), y.NetString()) + }) + + return !slices.EqualFunc(a, b, func(x, y *route.Route) bool { + return x.NetString() == y.NetString() + }) +} + +func (n *Notifier) GetInitialRouteRanges() []string { + initialStrings := n.routesToStrings(n.initialRoutes) + sort.Strings(initialStrings) + return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes) +} + +func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string { + for _, r := range routes { + if r.Network.Addr().Is4() && r.Network.Bits() == 0 { + return append(slices.Clone(inputRanges), "::/0") + } + } + return inputRanges +} diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go new file mode 100644 index 000000000..bb125cfa4 --- /dev/null +++ b/client/internal/routemanager/notifier/notifier_ios.go @@ -0,0 +1,80 @@ +//go:build ios + +package notifier + +import ( + "net/netip" + "slices" + "sort" + "strings" + "sync" + + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/route" +) + +type Notifier struct { + currentPrefixes []string + + listener listener.NetworkChangeListener + listenerMux sync.Mutex +} + +func NewNotifier() *Notifier { + return &Notifier{} +} + +func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + n.listener = listener +} + +func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { + // iOS doesn't care about initial routes +} + +func (n *Notifier) OnNewRoutes(route.HAMap) { + // Not used on iOS +} + +func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { + newNets := make([]string, 0) + for _, prefix := range prefixes { + newNets = append(newNets, prefix.String()) + } + + sort.Strings(newNets) + + if slices.Equal(n.currentPrefixes, newNets) { + return + } + + n.currentPrefixes = newNets + n.notify() +} + +func (n *Notifier) notify() { + n.listenerMux.Lock() + defer n.listenerMux.Unlock() + if n.listener == nil { + return + } + + go func(l listener.NetworkChangeListener) { + l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ",")) + }(n.listener) +} + +func (n *Notifier) GetInitialRouteRanges() []string { + return nil +} + +func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string { + for _, r := range inputRanges { + if r == "0.0.0.0/0" { + return append(slices.Clone(inputRanges), "::/0") + } + } + return inputRanges +} diff --git a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go new file mode 100644 index 000000000..77045b839 --- /dev/null +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -0,0 +1,36 @@ +//go:build !android && !ios + +package notifier + +import ( + "net/netip" + + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/route" +) + +type Notifier struct{} + +func NewNotifier() *Notifier { + return &Notifier{} +} + +func (n *Notifier) SetListener(listener listener.NetworkChangeListener) { + // Not used on non-mobile platforms +} + +func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { + // Not used on non-mobile platforms +} + +func (n *Notifier) OnNewRoutes(idMap route.HAMap) { + // Not used on non-mobile platforms +} + +func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { + // Not used on non-mobile platforms +} + +func (n *Notifier) GetInitialRouteRanges() []string { + return []string{} +} \ No newline at end of file diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index c8b9338e0..d480fdf00 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/route" ) @@ -16,11 +17,11 @@ type Route struct { allowedIPsRefcounter *refcounter.AllowedIPsRefCounter } -func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route { +func NewRoute(params common.HandlerParams) *Route { return &Route{ - route: rt, - routeRefCounter: routeRefCounter, - allowedIPsRefcounter: allowedIPsRefCounter, + route: params.Route, + routeRefCounter: params.RouteRefCounter, + allowedIPsRefcounter: params.AllowedIPsRefCounter, } }