mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-09 07:15:15 +02:00
[client] Implement dns routes for Android (#3989)
This commit is contained in:
@ -203,8 +203,10 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if routes[0].IsDynamic() {
|
r := routes[0]
|
||||||
continue
|
netStr := r.Network.String()
|
||||||
|
if r.IsDynamic() {
|
||||||
|
netStr = r.Domains.SafeString()
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
@ -214,7 +216,7 @@ func (c *Client) Networks() *NetworkArray {
|
|||||||
}
|
}
|
||||||
network := Network{
|
network := Network{
|
||||||
Name: string(id),
|
Name: string(id),
|
||||||
Network: routes[0].Network.String(),
|
Network: netStr,
|
||||||
Peer: peer.FQDN,
|
Peer: peer.FQDN,
|
||||||
Status: peer.ConnStatus.String(),
|
Status: peer.ConnStatus.String(),
|
||||||
}
|
}
|
||||||
|
@ -104,6 +104,12 @@ type Manager struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
|
||||||
blockRule firewall.Rule
|
blockRule firewall.Rule
|
||||||
|
|
||||||
|
// Internal 1:1 DNAT
|
||||||
|
dnatEnabled atomic.Bool
|
||||||
|
dnatMappings map[netip.Addr]netip.Addr
|
||||||
|
dnatMutex sync.RWMutex
|
||||||
|
dnatBiMap *biDNATMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@ -189,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
@ -519,22 +526,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
|
||||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return nil, errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.AddDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
|
||||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return errNatNotSupported
|
|
||||||
}
|
|
||||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
@ -581,14 +572,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// FilterOutBound filters outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
|
||||||
return m.processOutgoingHooks(packetData, size)
|
return m.filterOutbound(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// FilterInbound filters incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
|
||||||
return m.dropFilter(packetData, size)
|
return m.filterInbound(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs updates the list of local IPs
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
@ -596,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@ -618,8 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// for netflow we keep track even if the firewall is stateless
|
|
||||||
m.trackOutbound(d, srcIP, dstIP, size)
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
|
m.translateOutboundDNAT(packetData, d)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -723,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets.
|
// filterInbound implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@ -747,8 +738,15 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
// Re-decode after translation to get original addresses
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
srcIP, dstIP = m.extractIPs(d)
|
||||||
|
}
|
||||||
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.processOutgoingHooks(testOut, 0)
|
manager.filterOutbound(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn, 0)
|
manager.filterInbound(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
manager.filterOutbound(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx], 0)
|
manager.filterInbound(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn, 0)
|
manager.filterOutbound(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck, 0)
|
manager.filterInbound(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack, 0)
|
manager.filterOutbound(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request, 0)
|
manager.filterOutbound(p.request, 0)
|
||||||
manager.dropFilter(p.response, 0)
|
manager.filterInbound(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient, 0)
|
manager.filterOutbound(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer, 0)
|
manager.filterInbound(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer, 0)
|
manager.filterInbound(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient, 0)
|
manager.filterOutbound(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
manager.filterOutbound(outPackets[connIdx], 0)
|
||||||
manager.dropFilter(inPackets[connIdx], 0)
|
manager.filterInbound(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn, 0)
|
manager.filterOutbound(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck, 0)
|
manager.filterInbound(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack, 0)
|
manager.filterOutbound(p.ack, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request, 0)
|
manager.filterOutbound(p.request, 0)
|
||||||
manager.dropFilter(p.response, 0)
|
manager.filterInbound(p.response, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient, 0)
|
manager.filterOutbound(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer, 0)
|
manager.filterInbound(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer, 0)
|
manager.filterInbound(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient, 0)
|
manager.filterOutbound(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.DropIncoming(packet, 0)
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.DropIncoming(packet, 0)
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), 0) {
|
if m.filterInbound(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
result := manager.filterOutbound(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
result = manager.filterOutbound(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
drop = manager.filterInbound(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
408
client/firewall/uspfilter/nat.go
Normal file
408
client/firewall/uspfilter/nat.go
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||||
|
|
||||||
|
func ipv4Checksum(header []byte) uint16 {
|
||||||
|
if len(header) < 20 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var sum1, sum2 uint32
|
||||||
|
|
||||||
|
// Parallel processing - unroll and compute two sums simultaneously
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
||||||
|
// Skip checksum field at [10:12]
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
|
||||||
|
|
||||||
|
sum := sum1 + sum2
|
||||||
|
|
||||||
|
// Handle remaining bytes for headers > 20 bytes
|
||||||
|
for i := 20; i < len(header)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(header)%2 == 1 {
|
||||||
|
sum += uint32(header[len(header)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimized carry fold - single iteration handles most cases
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func icmpChecksum(data []byte) uint16 {
|
||||||
|
var sum1, sum2, sum3, sum4 uint32
|
||||||
|
i := 0
|
||||||
|
|
||||||
|
// Process 16 bytes at once with 4 parallel accumulators
|
||||||
|
for i <= len(data)-16 {
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
||||||
|
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
|
||||||
|
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
|
||||||
|
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
|
||||||
|
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
|
||||||
|
i += 16
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := sum1 + sum2 + sum3 + sum4
|
||||||
|
|
||||||
|
// Handle remaining bytes
|
||||||
|
for i < len(data)-1 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||||
|
i += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data)%2 == 1 {
|
||||||
|
sum += uint32(data[len(data)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
type biDNATMap struct {
|
||||||
|
forward map[netip.Addr]netip.Addr
|
||||||
|
reverse map[netip.Addr]netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBiDNATMap() *biDNATMap {
|
||||||
|
return &biDNATMap{
|
||||||
|
forward: make(map[netip.Addr]netip.Addr),
|
||||||
|
reverse: make(map[netip.Addr]netip.Addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||||
|
b.forward[original] = translated
|
||||||
|
b.reverse[translated] = original
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) delete(original netip.Addr) {
|
||||||
|
if translated, exists := b.forward[original]; exists {
|
||||||
|
delete(b.forward, original)
|
||||||
|
delete(b.reverse, translated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||||
|
translated, exists := b.forward[original]
|
||||||
|
return translated, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||||
|
original, exists := b.reverse[translated]
|
||||||
|
return original, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||||
|
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||||
|
return fmt.Errorf("invalid IP addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||||
|
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
defer m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
// Initialize both maps together if either is nil
|
||||||
|
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||||
|
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||||
|
m.dnatBiMap = newBiDNATMap()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMappings[originalAddr] = translatedAddr
|
||||||
|
m.dnatBiMap.set(originalAddr, translatedAddr)
|
||||||
|
|
||||||
|
if len(m.dnatMappings) == 1 {
|
||||||
|
m.dnatEnabled.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||||
|
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
defer m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||||
|
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.dnatMappings, originalAddr)
|
||||||
|
m.dnatBiMap.delete(originalAddr)
|
||||||
|
if len(m.dnatMappings) == 0 {
|
||||||
|
m.dnatEnabled.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDNATTranslation returns the translated address if a mapping exists
|
||||||
|
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return addr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
translated, exists := m.dnatBiMap.getTranslated(addr)
|
||||||
|
m.dnatMutex.RUnlock()
|
||||||
|
return translated, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// findReverseDNATMapping finds original address for return traffic
|
||||||
|
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return translatedAddr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
|
||||||
|
m.dnatMutex.RUnlock()
|
||||||
|
return original, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||||
|
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||||
|
|
||||||
|
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||||
|
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||||
|
|
||||||
|
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketDestination replaces destination IP in the packet
|
||||||
|
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||||
|
return ErrIPv4Only
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldDst [4]byte
|
||||||
|
copy(oldDst[:], packetData[16:20])
|
||||||
|
newDst := newIP.As4()
|
||||||
|
|
||||||
|
copy(packetData[16:20], newDst[:])
|
||||||
|
|
||||||
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||||
|
return fmt.Errorf("invalid IP header length")
|
||||||
|
}
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
|
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||||
|
|
||||||
|
if len(d.decoded) > 1 {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketSource replaces the source IP address in the packet
|
||||||
|
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||||
|
return ErrIPv4Only
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldSrc [4]byte
|
||||||
|
copy(oldSrc[:], packetData[12:16])
|
||||||
|
newSrc := newIP.As4()
|
||||||
|
|
||||||
|
copy(packetData[12:16], newSrc[:])
|
||||||
|
|
||||||
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||||
|
return fmt.Errorf("invalid IP header length")
|
||||||
|
}
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
|
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||||
|
|
||||||
|
if len(d.decoded) > 1 {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
tcpStart := ipHeaderLen
|
||||||
|
if len(packetData) < tcpStart+18 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := tcpStart + 16
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
udpStart := ipHeaderLen
|
||||||
|
if len(packetData) < udpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := udpStart + 6
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
|
||||||
|
if oldChecksum == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||||
|
icmpStart := ipHeaderLen
|
||||||
|
if len(packetData) < icmpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpData := packetData[icmpStart:]
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||||
|
checksum := icmpChecksum(icmpData)
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||||
|
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||||
|
sum := uint32(^oldChecksum)
|
||||||
|
|
||||||
|
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||||
|
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||||
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||||
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||||
|
} else {
|
||||||
|
// Fallback for other lengths
|
||||||
|
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||||
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(oldBytes)%2 == 1 {
|
||||||
|
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(newBytes)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(newBytes)%2 == 1 {
|
||||||
|
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
|
}
|
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
@ -0,0 +1,416 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkDNATTranslation measures the performance of DNAT operations
|
||||||
|
func BenchmarkDNATTranslation(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
proto layers.IPProtocol
|
||||||
|
setupDNAT bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tcp_with_dnat",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
setupDNAT: true,
|
||||||
|
description: "TCP packet with DNAT translation enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp_without_dnat",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
setupDNAT: false,
|
||||||
|
description: "TCP packet without DNAT (baseline)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "udp_with_dnat",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
setupDNAT: true,
|
||||||
|
description: "UDP packet with DNAT translation enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "udp_without_dnat",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
setupDNAT: false,
|
||||||
|
description: "UDP packet without DNAT (baseline)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "icmp_with_dnat",
|
||||||
|
proto: layers.IPProtocolICMPv4,
|
||||||
|
setupDNAT: true,
|
||||||
|
description: "ICMP packet with DNAT translation enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "icmp_without_dnat",
|
||||||
|
proto: layers.IPProtocolICMPv4,
|
||||||
|
setupDNAT: false,
|
||||||
|
description: "ICMP packet without DNAT (baseline)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Setup DNAT mapping if needed
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
|
||||||
|
if sc.setupDNAT {
|
||||||
|
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test packets
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||||
|
|
||||||
|
// Pre-establish connection for reverse DNAT test
|
||||||
|
if sc.setupDNAT {
|
||||||
|
manager.filterOutbound(outboundPacket, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
// Benchmark outbound DNAT translation
|
||||||
|
b.Run("outbound", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Create fresh packet each time since translation modifies it
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||||
|
manager.filterOutbound(packet, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Benchmark inbound reverse DNAT translation
|
||||||
|
if sc.setupDNAT {
|
||||||
|
b.Run("inbound_reverse", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Create fresh packet each time since translation modifies it
|
||||||
|
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
|
||||||
|
manager.filterInbound(packet, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
||||||
|
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Setup multiple DNAT mappings
|
||||||
|
numMappings := 100
|
||||||
|
originalIPs := make([]netip.Addr, numMappings)
|
||||||
|
translatedIPs := make([]netip.Addr, numMappings)
|
||||||
|
|
||||||
|
for i := 0; i < numMappings; i++ {
|
||||||
|
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
|
||||||
|
// Pre-generate packets
|
||||||
|
outboundPackets := make([][]byte, numMappings)
|
||||||
|
inboundPackets := make([][]byte, numMappings)
|
||||||
|
for i := 0; i < numMappings; i++ {
|
||||||
|
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
|
||||||
|
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||||
|
// Establish connections
|
||||||
|
manager.filterOutbound(outboundPackets[i], 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
b.Run("concurrent_outbound", func(b *testing.B) {
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
idx := i % numMappings
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
|
||||||
|
manager.filterOutbound(packet, 0)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("concurrent_inbound", func(b *testing.B) {
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
idx := i % numMappings
|
||||||
|
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||||
|
manager.filterInbound(packet, 0)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
|
||||||
|
func BenchmarkDNATScaling(b *testing.B) {
|
||||||
|
mappingCounts := []int{1, 10, 100, 1000}
|
||||||
|
|
||||||
|
for _, count := range mappingCounts {
|
||||||
|
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Setup DNAT mappings
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with the last mapping added (worst case for lookup)
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
|
||||||
|
manager.filterOutbound(packet, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDNATTestPacket creates a test packet for DNAT benchmarking
|
||||||
|
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP.AsSlice(),
|
||||||
|
DstIP: dstIP.AsSlice(),
|
||||||
|
Protocol: proto,
|
||||||
|
}
|
||||||
|
|
||||||
|
var transportLayer gopacket.SerializableLayer
|
||||||
|
switch proto {
|
||||||
|
case layers.IPProtocolTCP:
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
}
|
||||||
|
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = tcp
|
||||||
|
case layers.IPProtocolUDP:
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = udp
|
||||||
|
case layers.IPProtocolICMPv4:
|
||||||
|
icmp := &layers.ICMPv4{
|
||||||
|
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||||
|
}
|
||||||
|
transportLayer = icmp
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||||
|
require.NoError(tb, err)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
|
||||||
|
func BenchmarkChecksumUpdate(b *testing.B) {
|
||||||
|
// Create test data for checksum calculations
|
||||||
|
testData := make([]byte, 64) // Typical packet size for checksum testing
|
||||||
|
for i := range testData {
|
||||||
|
testData[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("ipv4_checksum", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("icmp_checksum", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = icmpChecksum(testData)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("incremental_update", func(b *testing.B) {
|
||||||
|
oldBytes := []byte{192, 168, 1, 100}
|
||||||
|
newBytes := []byte{10, 0, 0, 100}
|
||||||
|
oldChecksum := uint16(0x1234)
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
||||||
|
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Create fresh packet each time to isolate allocation testing
|
||||||
|
testPacket := make([]byte, len(packet))
|
||||||
|
copy(testPacket, packet)
|
||||||
|
|
||||||
|
// Parse the packet fresh each time to get a clean decoder
|
||||||
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
manager.translateOutboundDNAT(testPacket, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
|
||||||
|
func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||||
|
// Create a test packet
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
b.Run("direct_byte_access", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Direct extraction from packet bytes
|
||||||
|
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("decoder_extraction", func(b *testing.B) {
|
||||||
|
// Create decoder once for comparison
|
||||||
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Extract using decoder (traditional method)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
|
_ = dst
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
|
||||||
|
func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||||
|
// Create test IPv4 header (20 bytes)
|
||||||
|
header := make([]byte, 20)
|
||||||
|
for i := range header {
|
||||||
|
header[i] = byte(i)
|
||||||
|
}
|
||||||
|
// Clear checksum field
|
||||||
|
header[10] = 0
|
||||||
|
header[11] = 0
|
||||||
|
|
||||||
|
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = ipv4Checksum(header)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test incremental checksum updates
|
||||||
|
oldIP := []byte{192, 168, 1, 100}
|
||||||
|
newIP := []byte{10, 0, 0, 100}
|
||||||
|
oldChecksum := uint16(0x1234)
|
||||||
|
|
||||||
|
b.Run("optimized_incremental_update", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
145
client/firewall/uspfilter/nat_test.go
Normal file
145
client/firewall/uspfilter/nat_test.go
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
||||||
|
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
|
||||||
|
// Add DNAT mapping
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
protocol layers.IPProtocol
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{"TCP", layers.IPProtocolTCP, 12345, 80},
|
||||||
|
{"UDP", layers.IPProtocolUDP, 12345, 53},
|
||||||
|
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Test outbound DNAT translation
|
||||||
|
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||||
|
originalOutbound := make([]byte, len(outboundPacket))
|
||||||
|
copy(originalOutbound, outboundPacket)
|
||||||
|
|
||||||
|
// Process outbound packet (should translate destination)
|
||||||
|
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
|
||||||
|
require.True(t, translated, "Outbound packet should be translated")
|
||||||
|
|
||||||
|
// Verify destination IP was changed
|
||||||
|
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
|
||||||
|
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
|
||||||
|
|
||||||
|
// Test inbound reverse DNAT translation
|
||||||
|
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
|
||||||
|
originalInbound := make([]byte, len(inboundPacket))
|
||||||
|
copy(originalInbound, inboundPacket)
|
||||||
|
|
||||||
|
// Process inbound packet (should reverse translate source)
|
||||||
|
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
|
||||||
|
require.True(t, reversed, "Inbound packet should be reverse translated")
|
||||||
|
|
||||||
|
// Verify source IP was changed back to original
|
||||||
|
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
|
||||||
|
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
|
||||||
|
|
||||||
|
// Test that checksums are recalculated correctly
|
||||||
|
if tc.protocol != layers.IPProtocolICMPv4 {
|
||||||
|
// For TCP/UDP, verify the transport checksum was updated
|
||||||
|
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
|
||||||
|
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePacket helper to create a decoder for testing
|
||||||
|
func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||||
|
t.Helper()
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
|
||||||
|
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
||||||
|
func TestDNATMappingManagement(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
|
||||||
|
// Test adding mapping
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify mapping exists
|
||||||
|
result, exists := manager.getDNATTranslation(originalIP)
|
||||||
|
require.True(t, exists)
|
||||||
|
require.Equal(t, translatedIP, result)
|
||||||
|
|
||||||
|
// Test reverse lookup
|
||||||
|
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
|
||||||
|
require.True(t, exists)
|
||||||
|
require.Equal(t, originalIP, reverseResult)
|
||||||
|
|
||||||
|
// Test removing mapping
|
||||||
|
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify mapping no longer exists
|
||||||
|
_, exists = manager.getDNATTranslation(originalIP)
|
||||||
|
require.False(t, exists)
|
||||||
|
|
||||||
|
_, exists = manager.findReverseDNATMapping(translatedIP)
|
||||||
|
require.False(t, exists)
|
||||||
|
|
||||||
|
// Test error cases
|
||||||
|
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
|
||||||
|
require.Error(t, err, "Should reject invalid original IP")
|
||||||
|
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
|
||||||
|
require.Error(t, err, "Should reject invalid translated IP")
|
||||||
|
|
||||||
|
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||||
|
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||||
|
}
|
@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.processOutgoingHooks(packetData, 0)
|
dropped := m.filterOutbound(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
@ -9,11 +9,11 @@ import (
|
|||||||
|
|
||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// DropOutgoing filter outgoing packets from host to external destinations
|
// FilterOutbound filter outgoing packets from host to external destinations
|
||||||
DropOutgoing(packetData []byte, size int) bool
|
FilterOutbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// DropIncoming filter incoming packets from external sources to host
|
// FilterInbound filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte, size int) bool
|
FilterInbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// FilterInbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// FilterInbound indicates an expected call of FilterInbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// FilterOutbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
|
@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// FilterInbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// FilterInbound indicates an expected call of FilterInbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// FilterOutbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
// SetNetwork mocks base method.
|
||||||
|
@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
|
|
||||||
|
@ -383,7 +383,13 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.stateManager.Start()
|
e.stateManager.Start()
|
||||||
|
|
||||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
|
||||||
|
if err != nil {
|
||||||
|
e.close()
|
||||||
|
return fmt.Errorf("read initial settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer, err := e.newDnsServer(dnsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("create dns server: %w", err)
|
return fmt.Errorf("create dns server: %w", err)
|
||||||
@ -400,6 +406,7 @@ func (e *Engine) Start() error {
|
|||||||
InitialRoutes: initialRoutes,
|
InitialRoutes: initialRoutes,
|
||||||
StateManager: e.stateManager,
|
StateManager: e.stateManager,
|
||||||
DNSServer: dnsServer,
|
DNSServer: dnsServer,
|
||||||
|
DNSFeatureFlag: dnsFeatureFlag,
|
||||||
PeerStore: e.peerStore,
|
PeerStore: e.peerStore,
|
||||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||||
@ -488,9 +495,9 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) initFirewall() error {
|
func (e *Engine) initFirewall() error {
|
||||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("enable server router: %w", err)
|
return fmt.Errorf("set firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.config.BlockLANAccess {
|
if e.config.BlockLANAccess {
|
||||||
@ -1009,8 +1016,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
|
||||||
|
|
||||||
// apply routes first, route related actions might depend on routing being enabled
|
// apply routes first, route related actions might depend on routing being enabled
|
||||||
routes := toRoutes(networkMap.GetRoutes())
|
routes := toRoutes(networkMap.GetRoutes())
|
||||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||||
@ -1021,6 +1026,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("failed to update routes: %v", err)
|
log.Errorf("failed to update routes: %v", err)
|
||||||
}
|
}
|
||||||
@ -1489,7 +1495,12 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
// nolint:nilnil
|
||||||
|
return nil, nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
info := system.GetInfo(e.ctx)
|
info := system.GetInfo(e.ctx)
|
||||||
info.SetFlags(
|
info.SetFlags(
|
||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
@ -1506,11 +1517,12 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
|||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, false, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||||
return routes, &dnsCfg, nil
|
dnsFeatureFlag := toDNSFeatureFlag(netMap)
|
||||||
|
return routes, &dnsCfg, dnsFeatureFlag, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||||
@ -1558,18 +1570,14 @@ func (e *Engine) wgInterfaceCreate() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||||
// due to tests where we are using a mocked version of the DNS server
|
// due to tests where we are using a mocked version of the DNS server
|
||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
return nil, e.dnsServer, nil
|
return e.dnsServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "android":
|
case "android":
|
||||||
routes, dnsConfig, err := e.readInitialSettings()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
dnsServer := dns.NewDefaultServerPermanentUpstream(
|
||||||
e.ctx,
|
e.ctx,
|
||||||
e.wgInterface,
|
e.wgInterface,
|
||||||
@ -1580,19 +1588,19 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
)
|
)
|
||||||
go e.mobileDep.DnsReadyListener.OnReady()
|
go e.mobileDep.DnsReadyListener.OnReady()
|
||||||
return routes, dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
case "ios":
|
case "ios":
|
||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||||
return nil, dnsServer, nil
|
return dnsServer, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, dnsServer, nil
|
return dnsServer, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,11 +10,10 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -553,41 +552,16 @@ func (w *Watcher) Stop() {
|
|||||||
w.currentChosenStatus = nil
|
w.currentChosenStatus = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandlerFromRoute(
|
func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||||
rt *route.Route,
|
switch handlerType(params.Route, params.UseNewDNSRoute) {
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
dnsRouterInteval time.Duration,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
wgInterface iface.WGIface,
|
|
||||||
dnsServer nbdns.Server,
|
|
||||||
peerStore *peerstore.Store,
|
|
||||||
useNewDNSRoute bool,
|
|
||||||
) RouteHandler {
|
|
||||||
switch handlerType(rt, useNewDNSRoute) {
|
|
||||||
case handlerTypeDnsInterceptor:
|
case handlerTypeDnsInterceptor:
|
||||||
return dnsinterceptor.New(
|
return dnsinterceptor.New(params)
|
||||||
rt,
|
|
||||||
routeRefCounter,
|
|
||||||
allowedIPsRefCounter,
|
|
||||||
statusRecorder,
|
|
||||||
dnsServer,
|
|
||||||
wgInterface,
|
|
||||||
peerStore,
|
|
||||||
)
|
|
||||||
case handlerTypeDynamic:
|
case handlerTypeDynamic:
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||||
return dynamic.NewRoute(
|
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||||
rt,
|
return dynamic.NewRoute(params, dnsAddr)
|
||||||
routeRefCounter,
|
|
||||||
allowedIPsRefCounter,
|
|
||||||
dnsRouterInteval,
|
|
||||||
statusRecorder,
|
|
||||||
wgInterface,
|
|
||||||
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
|
||||||
)
|
|
||||||
default:
|
default:
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
return static.NewRoute(params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,12 +7,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetBestrouteFromStatuses(t *testing.T) {
|
func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
statuses map[route.ID]routerPeerStatus
|
statuses map[route.ID]routerPeerStatus
|
||||||
@ -811,9 +811,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
currentRoute = tc.existingRoutes[tc.currentRoute]
|
currentRoute = tc.existingRoutes[tc.currentRoute]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params := common.HandlerParams{
|
||||||
|
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||||
|
}
|
||||||
// create new clientNetwork
|
// create new clientNetwork
|
||||||
client := &Watcher{
|
client := &Watcher{
|
||||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
handler: static.NewRoute(params),
|
||||||
routes: tc.existingRoutes,
|
routes: tc.existingRoutes,
|
||||||
currentChosen: currentRoute,
|
currentChosen: currentRoute,
|
||||||
}
|
}
|
||||||
|
28
client/internal/routemanager/common/params.go
Normal file
28
client/internal/routemanager/common/params.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandlerParams struct {
|
||||||
|
Route *route.Route
|
||||||
|
RouteRefCounter *refcounter.RouteRefCounter
|
||||||
|
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
|
DnsRouterInterval time.Duration
|
||||||
|
StatusRecorder *peer.Status
|
||||||
|
WgInterface iface.WGIface
|
||||||
|
DnsServer dns.Server
|
||||||
|
PeerStore *peerstore.Store
|
||||||
|
UseNewDNSRoute bool
|
||||||
|
Firewall manager.Manager
|
||||||
|
FakeIPManager *fakeip.Manager
|
||||||
|
}
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -12,11 +13,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -24,6 +28,11 @@ import (
|
|||||||
|
|
||||||
type domainMap map[domain.Domain][]netip.Prefix
|
type domainMap map[domain.Domain][]netip.Prefix
|
||||||
|
|
||||||
|
type internalDNATer interface {
|
||||||
|
RemoveInternalDNATMapping(netip.Addr) error
|
||||||
|
AddInternalDNATMapping(netip.Addr, netip.Addr) error
|
||||||
|
}
|
||||||
|
|
||||||
type wgInterface interface {
|
type wgInterface interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
@ -40,26 +49,22 @@ type DnsInterceptor struct {
|
|||||||
interceptedDomains domainMap
|
interceptedDomains domainMap
|
||||||
wgInterface wgInterface
|
wgInterface wgInterface
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
|
firewall firewall.Manager
|
||||||
|
fakeIPManager *fakeip.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
func New(params common.HandlerParams) *DnsInterceptor {
|
||||||
rt *route.Route,
|
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
dnsServer nbdns.Server,
|
|
||||||
wgInterface wgInterface,
|
|
||||||
peerStore *peerstore.Store,
|
|
||||||
) *DnsInterceptor {
|
|
||||||
return &DnsInterceptor{
|
return &DnsInterceptor{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: params.StatusRecorder,
|
||||||
dnsServer: dnsServer,
|
dnsServer: params.DnsServer,
|
||||||
wgInterface: wgInterface,
|
wgInterface: params.WgInterface,
|
||||||
|
peerStore: params.PeerStore,
|
||||||
|
firewall: params.Firewall,
|
||||||
|
fakeIPManager: params.FakeIPManager,
|
||||||
interceptedDomains: make(domainMap),
|
interceptedDomains: make(domainMap),
|
||||||
peerStore: peerStore,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for domain, prefixes := range d.interceptedDomains {
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
// Routes should use fake IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||||
|
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedIPs should use real IPs
|
||||||
if d.currentPeerKey != "" {
|
if d.currentPeerKey != "" {
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.cleanupDNATMappings()
|
||||||
|
|
||||||
for _, domain := range d.route.Domains {
|
for _, domain := range d.route.Domains {
|
||||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||||
}
|
}
|
||||||
@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
|
||||||
|
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
|
||||||
|
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
|
||||||
|
return realPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
|
||||||
|
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
|
||||||
|
}
|
||||||
|
|
||||||
|
return realPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
|
||||||
|
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
|
||||||
|
// AllowedIPs always use real IPs
|
||||||
|
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ref.Count > 1 && ref.Out != peerKey {
|
||||||
|
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||||
|
realPrefix.Addr(),
|
||||||
|
domain.SafeString(),
|
||||||
|
ref.Out,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
|
||||||
|
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
|
||||||
|
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
|
||||||
|
routePrefix := d.transformRealToFakePrefix(realPrefix)
|
||||||
|
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
|
||||||
|
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to AllowedIPs if we have a current peer (uses real IPs)
|
||||||
|
if d.currentPeerKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
|
||||||
|
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
|
||||||
|
if d.currentPeerKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowedIPs use real IPs
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
|
||||||
|
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for domain, prefixes := range d.interceptedDomains {
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
// AllowedIPs use real IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
|
||||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
merr = multierror.Append(merr, err)
|
||||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
|
||||||
prefix.Addr(),
|
|
||||||
domain.SafeString(),
|
|
||||||
ref.Out,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, prefixes := range d.interceptedDomains {
|
for _, prefixes := range d.interceptedDomains {
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
|
// AllowedIPs use real IPs
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
}
|
}
|
||||||
@ -287,6 +356,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||||
log.Errorf("failed to update domain prefixes: %v", err)
|
log.Errorf("failed to update domain prefixes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.replaceIPsInDNSResponse(r, newPrefixes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,6 +368,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// logPrefixChanges handles the logging for prefix changes
|
||||||
|
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
||||||
|
if len(toAdd) > 0 {
|
||||||
|
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
|
resolvedDomain.SafeString(),
|
||||||
|
originalDomain.SafeString(),
|
||||||
|
toAdd)
|
||||||
|
}
|
||||||
|
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||||
|
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
|
resolvedDomain.SafeString(),
|
||||||
|
originalDomain.SafeString(),
|
||||||
|
toRemove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
@ -305,70 +392,163 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
var dnatMappings map[netip.Addr]netip.Addr
|
||||||
|
|
||||||
|
// Handle DNAT mappings for new prefixes
|
||||||
|
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
|
||||||
|
dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||||
|
for _, prefix := range toAdd {
|
||||||
|
realIP := prefix.Addr()
|
||||||
|
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||||
|
dnatMappings[fakeIP] = realIP
|
||||||
|
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||||
|
} else {
|
||||||
|
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add new prefixes
|
// Add new prefixes
|
||||||
for _, prefix := range toAdd {
|
for _, prefix := range toAdd {
|
||||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, err)
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.currentPeerKey == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
|
||||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
|
||||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
|
||||||
prefix.Addr(),
|
|
||||||
resolvedDomain.SafeString(),
|
|
||||||
ref.Out,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.addDNATMappings(dnatMappings)
|
||||||
|
|
||||||
if !d.route.KeepRoute {
|
if !d.route.KeepRoute {
|
||||||
// Remove old prefixes
|
// Remove old prefixes
|
||||||
for _, prefix := range toRemove {
|
for _, prefix := range toRemove {
|
||||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
// Routes use fake IPs
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
routePrefix := d.transformRealToFakePrefix(prefix)
|
||||||
|
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
|
||||||
}
|
}
|
||||||
if d.currentPeerKey != "" {
|
// AllowedIPs use real IPs
|
||||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
if err := d.removeAllowedIP(prefix); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
merr = multierror.Append(merr, err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.removeDNATMappings(toRemove)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update domain prefixes using resolved domain as key
|
// Update domain prefixes using resolved domain as key - store real IPs
|
||||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||||
if d.route.KeepRoute {
|
if d.route.KeepRoute {
|
||||||
// replace stored prefixes with old + added
|
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
newPrefixes = append(oldPrefixes, toAdd...)
|
newPrefixes = append(oldPrefixes, toAdd...)
|
||||||
}
|
}
|
||||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||||
|
|
||||||
|
// Store real IPs for status (user-facing), not fake IPs
|
||||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||||
|
|
||||||
if len(toAdd) > 0 {
|
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
||||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
|
||||||
resolvedDomain.SafeString(),
|
|
||||||
originalDomain.SafeString(),
|
|
||||||
toAdd)
|
|
||||||
}
|
|
||||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
|
||||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
|
||||||
resolvedDomain.SafeString(),
|
|
||||||
originalDomain.SafeString(),
|
|
||||||
toRemove)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
||||||
|
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
||||||
|
if len(realPrefixes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatFirewall, ok := d.internalDnatFw()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range realPrefixes {
|
||||||
|
realIP := prefix.Addr()
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||||
|
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// internalDnatFw checks if the firewall supports internal DNAT
|
||||||
|
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
||||||
|
if d.firewall == nil || runtime.GOOS != "android" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
fw, ok := d.firewall.(internalDNATer)
|
||||||
|
return fw, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// addDNATMappings adds DNAT mappings to the firewall
|
||||||
|
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
||||||
|
if len(mappings) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnatFirewall, ok := d.internalDnatFw()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for fakeIP, realIP := range mappings {
|
||||||
|
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||||
|
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupDNATMappings removes all DNAT mappings for this interceptor
|
||||||
|
func (d *DnsInterceptor) cleanupDNATMappings() {
|
||||||
|
if _, ok := d.internalDnatFw(); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefixes := range d.interceptedDomains {
|
||||||
|
d.removeDNATMappings(prefixes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||||
|
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
||||||
|
if _, ok := d.internalDnatFw(); !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace A and AAAA records with fake IPs
|
||||||
|
for _, answer := range reply.Answer {
|
||||||
|
switch rr := answer.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
realIP, ok := netip.AddrFromSlice(rr.A)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
rr.A = fakeIP.AsSlice()
|
||||||
|
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
case *dns.AAAA:
|
||||||
|
realIP, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
|
rr.AAAA = fakeIP.AsSlice()
|
||||||
|
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||||
prefixSet := make(map[netip.Prefix]bool)
|
prefixSet := make(map[netip.Prefix]bool)
|
||||||
for _, prefix := range oldPrefixes {
|
for _, prefix := range oldPrefixes {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
@ -52,24 +53,16 @@ type Route struct {
|
|||||||
resolverAddr string
|
resolverAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(
|
func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
|
||||||
rt *route.Route,
|
|
||||||
routeRefCounter *refcounter.RouteRefCounter,
|
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
|
||||||
interval time.Duration,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
wgInterface iface.WGIface,
|
|
||||||
resolverAddr string,
|
|
||||||
) *Route {
|
|
||||||
return &Route{
|
return &Route{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
interval: interval,
|
interval: params.DnsRouterInterval,
|
||||||
dynamicDomains: domainMap{},
|
statusRecorder: params.StatusRecorder,
|
||||||
statusRecorder: statusRecorder,
|
wgInterface: params.WgInterface,
|
||||||
wgInterface: wgInterface,
|
|
||||||
resolverAddr: resolverAddr,
|
resolverAddr: resolverAddr,
|
||||||
|
dynamicDomains: domainMap{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
93
client/internal/routemanager/fakeip/fakeip.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package fakeip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
nextIP netip.Addr // Next IP to allocate
|
||||||
|
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
|
||||||
|
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
|
||||||
|
baseIP netip.Addr // First usable IP: 240.0.0.1
|
||||||
|
maxIP netip.Addr // Last usable IP: 240.255.255.254
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
|
||||||
|
func NewManager() *Manager {
|
||||||
|
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
|
||||||
|
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
|
||||||
|
|
||||||
|
return &Manager{
|
||||||
|
nextIP: baseIP,
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: baseIP,
|
||||||
|
maxIP: maxIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllocateFakeIP allocates a fake IP for the given real IP
|
||||||
|
// Returns the fake IP, or existing fake IP if already allocated
|
||||||
|
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
|
||||||
|
if !realIP.Is4() {
|
||||||
|
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if fakeIP, exists := m.allocated[realIP]; exists {
|
||||||
|
return fakeIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
startIP := m.nextIP
|
||||||
|
for {
|
||||||
|
currentIP := m.nextIP
|
||||||
|
|
||||||
|
// Advance to next IP, wrapping at boundary
|
||||||
|
if m.nextIP.Compare(m.maxIP) >= 0 {
|
||||||
|
m.nextIP = m.baseIP
|
||||||
|
} else {
|
||||||
|
m.nextIP = m.nextIP.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if current IP is available
|
||||||
|
if _, inUse := m.fakeToReal[currentIP]; !inUse {
|
||||||
|
m.allocated[realIP] = currentIP
|
||||||
|
m.fakeToReal[currentIP] = realIP
|
||||||
|
return currentIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent infinite loop if all IPs exhausted
|
||||||
|
if m.nextIP.Compare(startIP) == 0 {
|
||||||
|
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIP returns the fake IP for a real IP if it exists
|
||||||
|
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
fakeIP, exists := m.allocated[realIP]
|
||||||
|
return fakeIP, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
|
||||||
|
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
realIP, exists := m.fakeToReal[fakeIP]
|
||||||
|
return realIP, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFakeIPBlock returns the fake IP block used by this manager
|
||||||
|
func (m *Manager) GetFakeIPBlock() netip.Prefix {
|
||||||
|
return netip.MustParsePrefix("240.0.0.0/8")
|
||||||
|
}
|
240
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
240
client/internal/routemanager/fakeip/fakeip_test.go
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
package fakeip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManager(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
if manager.baseIP.String() != "240.0.0.1" {
|
||||||
|
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.maxIP.String() != "240.255.255.254" {
|
||||||
|
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.nextIP.Compare(manager.baseIP) != 0 {
|
||||||
|
t.Errorf("Expected nextIP to start at baseIP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllocateFakeIP(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIP := netip.MustParseAddr("8.8.8.8")
|
||||||
|
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fakeIP.Is4() {
|
||||||
|
t.Error("Fake IP should be IPv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check it's in the correct range
|
||||||
|
if fakeIP.As4()[0] != 240 {
|
||||||
|
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return same fake IP for same real IP
|
||||||
|
fakeIP2, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get existing fake IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP.Compare(fakeIP2) != 0 {
|
||||||
|
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIPv6 := netip.MustParseAddr("2001:db8::1")
|
||||||
|
|
||||||
|
_, err := manager.AllocateFakeIP(realIPv6)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for IPv6 address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFakeIP(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
realIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
|
// Should not exist initially
|
||||||
|
_, exists := manager.GetFakeIP(realIP)
|
||||||
|
if exists {
|
||||||
|
t.Error("Fake IP should not exist before allocation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate and check
|
||||||
|
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeIP, exists := manager.GetFakeIP(realIP)
|
||||||
|
if !exists {
|
||||||
|
t.Error("Fake IP should exist after allocation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP.Compare(expectedFakeIP) != 0 {
|
||||||
|
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleAllocations(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
allocations := make(map[netip.Addr]netip.Addr)
|
||||||
|
|
||||||
|
// Allocate multiple IPs
|
||||||
|
for i := 1; i <= 100; i++ {
|
||||||
|
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for duplicates
|
||||||
|
for _, existingFake := range allocations {
|
||||||
|
if fakeIP.Compare(existingFake) == 0 {
|
||||||
|
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allocations[realIP] = fakeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all allocations can be retrieved
|
||||||
|
for realIP, expectedFake := range allocations {
|
||||||
|
actualFake, exists := manager.GetFakeIP(realIP)
|
||||||
|
if !exists {
|
||||||
|
t.Errorf("Missing allocation for %s", realIP.String())
|
||||||
|
}
|
||||||
|
if actualFake.Compare(expectedFake) != 0 {
|
||||||
|
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFakeIPBlock(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
block := manager.GetFakeIPBlock()
|
||||||
|
|
||||||
|
expected := "240.0.0.0/8"
|
||||||
|
if block.String() != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, block.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
const numGoroutines = 50
|
||||||
|
const allocationsPerGoroutine = 10
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
|
||||||
|
|
||||||
|
// Concurrent allocations
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(goroutineID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < allocationsPerGoroutine; j++ {
|
||||||
|
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
|
||||||
|
fakeIP, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
results <- fakeIP
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
// Check for duplicates
|
||||||
|
seen := make(map[netip.Addr]bool)
|
||||||
|
count := 0
|
||||||
|
for fakeIP := range results {
|
||||||
|
if seen[fakeIP] {
|
||||||
|
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
|
||||||
|
}
|
||||||
|
seen[fakeIP] = true
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != numGoroutines*allocationsPerGoroutine {
|
||||||
|
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPExhaustion(t *testing.T) {
|
||||||
|
// Create a manager with limited range for testing
|
||||||
|
manager := &Manager{
|
||||||
|
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate all available IPs
|
||||||
|
realIPs := []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.0.0.1"),
|
||||||
|
netip.MustParseAddr("1.0.0.2"),
|
||||||
|
netip.MustParseAddr("1.0.0.3"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, realIP := range realIPs {
|
||||||
|
_, err := manager.AllocateFakeIP(realIP)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate fake IP: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to allocate one more - should fail
|
||||||
|
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected exhaustion error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapAround(t *testing.T) {
|
||||||
|
// Create manager starting near the end of range
|
||||||
|
manager := &Manager{
|
||||||
|
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||||
|
allocated: make(map[netip.Addr]netip.Addr),
|
||||||
|
fakeToReal: make(map[netip.Addr]netip.Addr),
|
||||||
|
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
|
||||||
|
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the last IP
|
||||||
|
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate first IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP1.String() != "240.0.0.254" {
|
||||||
|
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next allocation should wrap around to the beginning
|
||||||
|
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to allocate second IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeIP2.String() != "240.0.0.1" {
|
||||||
|
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
|
||||||
|
}
|
||||||
|
}
|
@ -8,9 +8,11 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@ -24,6 +26,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
@ -49,7 +53,7 @@ type Manager interface {
|
|||||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
EnableServerRouter(firewall firewall.Manager) error
|
SetFirewall(firewall.Manager) error
|
||||||
Stop(stateManager *statemanager.Manager)
|
Stop(stateManager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,6 +67,7 @@ type ManagerConfig struct {
|
|||||||
InitialRoutes []*route.Route
|
InitialRoutes []*route.Route
|
||||||
StateManager *statemanager.Manager
|
StateManager *statemanager.Manager
|
||||||
DNSServer dns.Server
|
DNSServer dns.Server
|
||||||
|
DNSFeatureFlag bool
|
||||||
PeerStore *peerstore.Store
|
PeerStore *peerstore.Store
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@ -89,11 +94,13 @@ type DefaultManager struct {
|
|||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
clientRoutes route.HAMap
|
clientRoutes route.HAMap
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
firewall firewall.Manager
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
useNewDNSRoute bool
|
useNewDNSRoute bool
|
||||||
disableClientRoutes bool
|
disableClientRoutes bool
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||||
|
fakeIPManager *fakeip.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(config ManagerConfig) *DefaultManager {
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
@ -129,11 +136,31 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
dm.setupAndroidRoutes(config)
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
|
||||||
}
|
}
|
||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||||
|
cr := m.initialClientRoutes(config.InitialRoutes)
|
||||||
|
|
||||||
|
routesForComparison := slices.Clone(cr)
|
||||||
|
|
||||||
|
if config.DNSFeatureFlag {
|
||||||
|
m.fakeIPManager = fakeip.NewManager()
|
||||||
|
|
||||||
|
id := uuid.NewString()
|
||||||
|
fakeIPRoute := &route.Route{
|
||||||
|
ID: route.ID(id),
|
||||||
|
Network: m.fakeIPManager.GetFakeIPBlock(),
|
||||||
|
NetID: route.NetID(id),
|
||||||
|
Peer: m.pubKey,
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
}
|
||||||
|
cr = append(cr, fakeIPRoute)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
@ -222,16 +249,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
|||||||
return routeselector.NewRouteSelector()
|
return routeselector.NewRouteSelector()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
// SetFirewall sets the firewall manager for the DefaultManager
|
||||||
if m.disableServerRoutes {
|
// Not thread-safe, should be called before starting the manager
|
||||||
|
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||||
|
m.firewall = firewall
|
||||||
|
|
||||||
|
if m.disableServerRoutes || firewall == nil {
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if firewall == nil {
|
|
||||||
return errors.New("firewall manager is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -299,17 +326,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for id, route := range toAdd {
|
for id, route := range toAdd {
|
||||||
handler := client.HandlerFromRoute(
|
params := common.HandlerParams{
|
||||||
route,
|
Route: route,
|
||||||
m.routeRefCounter,
|
RouteRefCounter: m.routeRefCounter,
|
||||||
m.allowedIPsRefCounter,
|
AllowedIPsRefCounter: m.allowedIPsRefCounter,
|
||||||
m.dnsRouteInterval,
|
DnsRouterInterval: m.dnsRouteInterval,
|
||||||
m.statusRecorder,
|
StatusRecorder: m.statusRecorder,
|
||||||
m.wgInterface,
|
WgInterface: m.wgInterface,
|
||||||
m.dnsServer,
|
DnsServer: m.dnsServer,
|
||||||
m.peerStore,
|
PeerStore: m.peerStore,
|
||||||
m.useNewDNSRoute,
|
UseNewDNSRoute: m.useNewDNSRoute,
|
||||||
)
|
Firewall: m.firewall,
|
||||||
|
FakeIPManager: m.fakeIPManager,
|
||||||
|
}
|
||||||
|
handler := client.HandlerFromRoute(params)
|
||||||
if err := handler.AddRoute(m.ctx); err != nil {
|
if err := handler.AddRoute(m.ctx); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
|
||||||
continue
|
continue
|
||||||
@ -517,6 +547,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
|
|||||||
for _, routes := range crMap {
|
for _, routes := range crMap {
|
||||||
rs = append(rs, routes...)
|
rs = append(rs, routes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,124 +0,0 @@
|
|||||||
package notifier
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Notifier struct {
|
|
||||||
initialRouteRanges []string
|
|
||||||
routeRanges []string
|
|
||||||
|
|
||||||
listener listener.NetworkChangeListener
|
|
||||||
listenerMux sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewNotifier() *Notifier {
|
|
||||||
return &Notifier{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|
||||||
n.listenerMux.Lock()
|
|
||||||
defer n.listenerMux.Unlock()
|
|
||||||
n.listener = listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
|
||||||
nets := make([]string, 0)
|
|
||||||
for _, r := range clientRoutes {
|
|
||||||
if r.IsDynamic() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nets = append(nets, r.Network.String())
|
|
||||||
}
|
|
||||||
sort.Strings(nets)
|
|
||||||
n.initialRouteRanges = nets
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
|
||||||
if runtime.GOOS != "android" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var newNets []string
|
|
||||||
for _, routes := range idMap {
|
|
||||||
for _, r := range routes {
|
|
||||||
if r.IsDynamic() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newNets = append(newNets, r.Network.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(newNets)
|
|
||||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
n.routeRanges = newNets
|
|
||||||
n.notify()
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnNewPrefixes is called from iOS only
|
|
||||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
|
||||||
newNets := make([]string, 0)
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
newNets = append(newNets, prefix.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(newNets)
|
|
||||||
if !n.hasDiff(n.routeRanges, newNets) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
n.routeRanges = newNets
|
|
||||||
n.notify()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) notify() {
|
|
||||||
n.listenerMux.Lock()
|
|
||||||
defer n.listenerMux.Unlock()
|
|
||||||
if n.listener == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(l listener.NetworkChangeListener) {
|
|
||||||
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
|
|
||||||
}(n.listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) hasDiff(a []string, b []string) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for i, v := range a {
|
|
||||||
if v != b[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
|
||||||
return addIPv6RangeIfNeeded(n.initialRouteRanges)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
|
|
||||||
func addIPv6RangeIfNeeded(inputRanges []string) []string {
|
|
||||||
ranges := inputRanges
|
|
||||||
for _, r := range inputRanges {
|
|
||||||
// we are intentionally adding the ipv6 default range in case of ipv4 default range
|
|
||||||
// to ensure that all traffic is managed by the tunnel interface on android
|
|
||||||
if r == "0.0.0.0/0" {
|
|
||||||
ranges = append(ranges, "::/0")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ranges
|
|
||||||
}
|
|
127
client/internal/routemanager/notifier/notifier_android.go
Normal file
127
client/internal/routemanager/notifier/notifier_android.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Notifier struct {
|
||||||
|
initialRoutes []*route.Route
|
||||||
|
currentRoutes []*route.Route
|
||||||
|
|
||||||
|
listener listener.NetworkChangeListener
|
||||||
|
listenerMux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNotifier() *Notifier {
|
||||||
|
return &Notifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
n.listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||||
|
// initialRoutes contains fake IP block for interface configuration
|
||||||
|
filteredInitial := make([]*route.Route, 0)
|
||||||
|
for _, r := range initialRoutes {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredInitial = append(filteredInitial, r)
|
||||||
|
}
|
||||||
|
n.initialRoutes = filteredInitial
|
||||||
|
|
||||||
|
// routesForComparison excludes fake IP block for comparison with new routes
|
||||||
|
filteredComparison := make([]*route.Route, 0)
|
||||||
|
for _, r := range routesForComparison {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredComparison = append(filteredComparison, r)
|
||||||
|
}
|
||||||
|
n.currentRoutes = filteredComparison
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
|
var newRoutes []*route.Route
|
||||||
|
for _, routes := range idMap {
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newRoutes = append(newRoutes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !n.hasRouteDiff(n.currentRoutes, newRoutes) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.currentRoutes = newRoutes
|
||||||
|
n.notify()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewPrefixes([]netip.Prefix) {
|
||||||
|
// Not used on Android
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) notify() {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
if n.listener == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
routeStrings := n.routesToStrings(n.currentRoutes)
|
||||||
|
sort.Strings(routeStrings)
|
||||||
|
go func(l listener.NetworkChangeListener) {
|
||||||
|
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
|
||||||
|
}(n.listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
|
||||||
|
nets := make([]string, 0, len(routes))
|
||||||
|
for _, r := range routes {
|
||||||
|
nets = append(nets, r.NetString())
|
||||||
|
}
|
||||||
|
return nets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool {
|
||||||
|
slices.SortFunc(a, func(x, y *route.Route) int {
|
||||||
|
return strings.Compare(x.NetString(), y.NetString())
|
||||||
|
})
|
||||||
|
slices.SortFunc(b, func(x, y *route.Route) int {
|
||||||
|
return strings.Compare(x.NetString(), y.NetString())
|
||||||
|
})
|
||||||
|
|
||||||
|
return !slices.EqualFunc(a, b, func(x, y *route.Route) bool {
|
||||||
|
return x.NetString() == y.NetString()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||||
|
initialStrings := n.routesToStrings(n.initialRoutes)
|
||||||
|
sort.Strings(initialStrings)
|
||||||
|
return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string {
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.Network.Addr().Is4() && r.Network.Bits() == 0 {
|
||||||
|
return append(slices.Clone(inputRanges), "::/0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return inputRanges
|
||||||
|
}
|
80
client/internal/routemanager/notifier/notifier_ios.go
Normal file
80
client/internal/routemanager/notifier/notifier_ios.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Notifier struct {
|
||||||
|
currentPrefixes []string
|
||||||
|
|
||||||
|
listener listener.NetworkChangeListener
|
||||||
|
listenerMux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNotifier() *Notifier {
|
||||||
|
return &Notifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
n.listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||||
|
// iOS doesn't care about initial routes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||||
|
// Not used on iOS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||||
|
newNets := make([]string, 0)
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
newNets = append(newNets, prefix.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(newNets)
|
||||||
|
|
||||||
|
if slices.Equal(n.currentPrefixes, newNets) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.currentPrefixes = newNets
|
||||||
|
n.notify()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) notify() {
|
||||||
|
n.listenerMux.Lock()
|
||||||
|
defer n.listenerMux.Unlock()
|
||||||
|
if n.listener == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(l listener.NetworkChangeListener) {
|
||||||
|
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ","))
|
||||||
|
}(n.listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string {
|
||||||
|
for _, r := range inputRanges {
|
||||||
|
if r == "0.0.0.0/0" {
|
||||||
|
return append(slices.Clone(inputRanges), "::/0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return inputRanges
|
||||||
|
}
|
36
client/internal/routemanager/notifier/notifier_other.go
Normal file
36
client/internal/routemanager/notifier/notifier_other.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Notifier struct{}
|
||||||
|
|
||||||
|
func NewNotifier() *Notifier {
|
||||||
|
return &Notifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@ -16,11 +17,11 @@ type Route struct {
|
|||||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
func NewRoute(params common.HandlerParams) *Route {
|
||||||
return &Route{
|
return &Route{
|
||||||
route: rt,
|
route: params.Route,
|
||||||
routeRefCounter: routeRefCounter,
|
routeRefCounter: params.RouteRefCounter,
|
||||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
allowedIPsRefcounter: params.AllowedIPsRefCounter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user