Replace net.IP with netip.Addr (#3425)

This commit is contained in:
Viktor Liu 2025-03-05 18:28:05 +01:00 committed by GitHub
parent 419ed275fa
commit e9f11fb11b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 341 additions and 309 deletions

View File

@ -4,6 +4,7 @@ package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
@ -17,8 +18,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
@ -35,8 +36,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if m.forwarder != nil {
m.forwarder.Stop()
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {

View File

@ -3,6 +3,7 @@ package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
@ -26,8 +27,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
@ -44,8 +45,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if m.forwarder != nil {
m.forwarder.Stop()
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {

View File

@ -2,7 +2,6 @@ package conntrack
import (
"fmt"
"net"
"net/netip"
"sync/atomic"
"time"
@ -52,16 +51,3 @@ type ConnKey struct {
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@ -2,7 +2,7 @@ package conntrack
import (
"context"
"net"
"net/netip"
"testing"
"github.com/sirupsen/logrus"
@ -21,11 +21,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
}
b.ResetTimer()
@ -46,11 +46,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
}
b.ResetTimer()

View File

@ -2,7 +2,6 @@ package conntrack
import (
"fmt"
"net"
"net/netip"
"sync"
"time"
@ -70,8 +69,13 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
return tracker
}
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
key := makeICMPKey(srcIP, dstIP, id, seq)
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16) (ICMPConnKey, bool) {
key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
Sequence: seq,
}
t.mutex.RLock()
conn, exists := t.connections[key]
@ -87,7 +91,7 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
@ -95,12 +99,12 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
// TODO: icmp doesn't need to extend the timeout
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
if exists {
@ -112,7 +116,7 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code)
return
}
@ -120,8 +124,8 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
@ -133,16 +137,21 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, key, conn)
t.sendEvent(nftypes.TypeStart, conn)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
Sequence: seq,
}
t.mutex.RLock()
conn, exists := t.connections[key]
@ -177,7 +186,7 @@ func (t *ICMPTracker) cleanup() {
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
t.sendEvent(nftypes.TypeEnd, key, conn)
t.sendEvent(nftypes.TypeEnd, conn)
}
}
}
@ -192,40 +201,28 @@ func (t *ICMPTracker) Close() {
t.mutex.Unlock()
}
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) {
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
})
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ICMPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
}
}

View File

@ -1,7 +1,7 @@
package conntrack
import (
"net"
"net/netip"
"testing"
)
@ -10,8 +10,8 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -23,8 +23,8 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {

View File

@ -3,7 +3,7 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
@ -144,8 +144,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
@ -154,7 +159,6 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.UpdateLastSeen()
conn.Unlock()
return key, true
@ -164,7 +168,7 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
@ -172,12 +176,12 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
if exists {
return
@ -187,14 +191,13 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.UpdateLastSeen()
conn.established.Store(false)
conn.tombstone.Store(false)
@ -205,12 +208,17 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, key, conn)
t.sendEvent(nftypes.TypeStart, conn)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
@ -233,13 +241,12 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
conn.Unlock()
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
t.sendEvent(nftypes.TypeEnd, conn)
return true
}
conn.Lock()
t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
@ -249,6 +256,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen()
state := conn.State
defer func() {
if state != conn.State {
@ -312,7 +321,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
t.sendEvent(nftypes.TypeEnd, conn)
}
case TCPStateClosing:
@ -321,7 +330,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
t.sendEvent(nftypes.TypeEnd, conn)
}
case TCPStateCloseWait:
@ -335,7 +344,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.SetTombstone()
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, key, conn)
t.sendEvent(nftypes.TypeEnd, conn)
t.logger.Trace("TCP connection %s closed gracefully", key)
}
}
@ -422,7 +431,7 @@ func (t *TCPTracker) cleanup() {
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, key, conn)
t.sendEvent(nftypes.TypeEnd, conn)
}
}
}
@ -453,15 +462,15 @@ func isValidFlagCombination(flags uint8) bool {
return true
}
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
})
}

View File

@ -1,7 +1,7 @@
package conntrack
import (
"net"
"net/netip"
"testing"
"time"
@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST()
// Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
@ -220,7 +225,7 @@ func TestRSTHandling(t *testing.T) {
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
@ -236,8 +241,8 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -249,8 +254,8 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
@ -267,8 +272,8 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
@ -291,8 +296,8 @@ func BenchmarkCleanup(b *testing.B) {
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}

View File

@ -1,7 +1,7 @@
package conntrack
import (
"net"
"net/netip"
"sync"
"time"
@ -54,7 +54,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
@ -62,12 +62,17 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
}
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
@ -82,7 +87,7 @@ func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
if exists {
return
@ -92,8 +97,8 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
@ -105,12 +110,17 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, key, conn)
t.sendEvent(nftypes.TypeStart, conn)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
@ -146,7 +156,7 @@ func (t *UDPTracker) cleanup() {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout)", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
t.sendEvent(nftypes.TypeEnd, conn)
}
}
}
@ -162,11 +172,16 @@ func (t *UDPTracker) Close() {
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
@ -176,15 +191,15 @@ func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
})
}

View File

@ -1,7 +1,6 @@
package conntrack
import (
"net"
"net/netip"
"testing"
"time"
@ -49,10 +48,15 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
// Verify connection was tracked
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
@ -66,8 +70,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
@ -76,8 +80,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
tests := []struct {
name string
srcIP net.IP
dstIP net.IP
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
sleep time.Duration
@ -94,7 +98,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
},
{
name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"),
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
@ -104,7 +108,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
@ -170,20 +174,20 @@ func TestUDPTracker_Cleanup(t *testing.T) {
// Add some connections
connections := []struct {
srcIP net.IP
dstIP net.IP
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
@ -215,8 +219,8 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -228,8 +232,8 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {

View File

@ -3,6 +3,7 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil
}
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil {
return m.checkBitmapBit(ipv4)
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
}
return false

View File

@ -2,6 +2,7 @@ package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr iface.WGAddress
testIP net.IP
testIP netip.Addr
expected bool
}{
{
@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.0.0.2"),
testIP: netip.MustParseAddr("127.0.0.2"),
expected: true,
},
{
@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.0.0.1"),
testIP: netip.MustParseAddr("127.0.0.1"),
expected: true,
},
{
@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("127.255.255.255"),
testIP: netip.MustParseAddr("127.255.255.255"),
expected: true,
},
{
@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("192.168.1.1"),
testIP: netip.MustParseAddr("192.168.1.1"),
expected: true,
},
{
@ -73,7 +74,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32),
},
},
testIP: net.ParseIP("192.168.1.2"),
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
@ -85,7 +86,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(64, 128),
},
},
testIP: net.ParseIP("fe80::1"),
testIP: netip.MustParseAddr("fe80::1"),
expected: false,
},
}
@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip))
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}

View File

@ -1,7 +1,6 @@
package uspfilter
import (
"net"
"net/netip"
"github.com/google/gopacket"
@ -13,7 +12,7 @@ import (
type PeerRule struct {
id string
mgmtId []byte
ip net.IP
ip netip.Addr
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType

View File

@ -2,7 +2,7 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"time"
"github.com/google/gopacket"
@ -53,8 +53,8 @@ type TraceResult struct {
}
type PacketTrace struct {
SourceIP net.IP
DestinationIP net.IP
SourceIP netip.Addr
DestinationIP netip.Addr
Protocol string
SourcePort uint16
DestinationPort uint16
@ -72,8 +72,8 @@ type TCPState struct {
}
type PacketBuilder struct {
SrcIP net.IP
DstIP net.IP
SrcIP netip.Addr
DstIP netip.Addr
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP,
DstIP: p.DstIP,
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}
}
@ -260,7 +260,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
@ -273,14 +273,14 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder
return trace
}
if m.nativeRouter {
if m.nativeRouter.Load() {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
msg := "No existing connection found"
if allowed {
@ -309,13 +309,12 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true
}
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
@ -341,7 +340,7 @@ func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled {
if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
@ -357,7 +356,7 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
@ -373,7 +372,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil {
if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}

View File

View File

@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@ -66,9 +67,9 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager
type Manager struct {
// outgoingRules is used for hooks only
outgoingRules map[string]RuleSet
outgoingRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool
@ -80,9 +81,9 @@ type Manager struct {
// indicates whether server routes are disabled
disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves
routingEnabled bool
routingEnabled atomic.Bool
// indicates whether we leave forwarding and filtering to the native firewall
nativeRouter bool
nativeRouter atomic.Bool
// indicates whether we track outbound connections
stateful bool
// indicates whether wireguards runs in netstack mode
@ -95,7 +96,7 @@ type Manager struct {
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder
forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger
flowLogger nftypes.FlowLogger
}
@ -168,18 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
outgoingRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
}
m.routingEnabled.Store(false)
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
@ -211,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil {
if m.forwarder.Load() == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
@ -255,20 +256,20 @@ func (m *Manager) determineRouting() error {
switch {
case disableUspRouting:
m.routingEnabled = false
m.nativeRouter = false
m.routingEnabled.Store(false)
m.nativeRouter.Store(false)
log.Info("userspace routing is disabled")
case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true
m.nativeRouter = true
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
log.Info("server routes are disabled")
case forceUserspaceRouter:
m.routingEnabled = true
m.nativeRouter = false
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
log.Info("userspace routing is forced")
@ -276,19 +277,19 @@ func (m *Manager) determineRouting() error {
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
m.routingEnabled = true
m.nativeRouter = true
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
log.Info("native routing is enabled")
default:
m.routingEnabled = true
m.nativeRouter = false
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
log.Info("userspace routing enabled by default")
}
if m.routingEnabled && !m.nativeRouter {
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
return m.initForwarder()
}
@ -297,24 +298,24 @@ func (m *Manager) determineRouting() error {
// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error {
if m.forwarder != nil {
if m.forwarder.Load() != nil {
return nil
}
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice()
if intf == nil {
m.routingEnabled = false
m.routingEnabled.Store(false)
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil {
m.routingEnabled = false
m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err)
}
m.forwarder = forwarder
m.forwarder.Store(forwarder)
log.Debug("forwarder initialized")
@ -330,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
}
@ -341,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair)
}
return nil
@ -360,17 +361,23 @@ func (m *Manager) AddPeerFiltering(
action firewall.Action,
_ string,
) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{
id: uuid.New().String(),
mgmtId: id,
ip: ip,
ip: i,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
}
if ipNormalized := ip.To4(); ipNormalized != nil {
if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
@ -395,10 +402,10 @@ func (m *Manager) AddPeerFiltering(
}
m.mutex.Lock()
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
m.incomingRules[r.ip][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
@ -412,13 +419,10 @@ func (m *Manager) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter && m.nativeFirewall != nil {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
@ -432,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
action: action,
}
m.mutex.Lock()
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter && m.nativeFirewall != nil {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
}
@ -468,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
delete(m.incomingRules[r.ip], r.id)
return nil
}
@ -519,9 +525,6 @@ func (m *Manager) UpdateLocalIPs() error {
}
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@ -534,7 +537,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return false
}
@ -551,14 +555,18 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
return false
}
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
default:
return nil, nil
return netip.Addr{}, netip.Addr{}
}
}
@ -585,7 +593,7 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
@ -598,7 +606,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
}
}
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
@ -611,8 +619,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
}
}
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, ipKey := range []netip.Addr{dstIP, netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
@ -627,9 +638,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@ -638,7 +646,7 @@ func (m *Manager) dropFilter(packetData []byte) bool {
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true
}
@ -658,15 +666,13 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort)
ruleId, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@ -674,8 +680,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
RuleID: ruleId,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcAddr,
DestIP: dstAddr,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
@ -700,12 +706,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
return false
}
if m.forwarder == nil {
if m.forwarder.Load() == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true
}
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
}
@ -715,16 +721,16 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
// Drop if routing is disabled
if !m.routingEnabled {
if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
// Pass to native stack if native router is enabled or forced
if m.nativeRouter {
if m.nativeRouter.Load() {
return false
}
@ -732,9 +738,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
srcPort, dstPort := getPortsFromPacket(d)
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
id, pnum, srcIP, srcPort, dstIP, dstPort)
@ -744,8 +747,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
RuleID: id,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcAddr,
DestIP: dstAddr,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
@ -754,7 +757,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
}
// Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err)
}
@ -799,7 +802,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
@ -844,20 +847,22 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) {
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
if m.isSpecialICMP(d) {
return nil, false
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
@ -882,10 +887,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false
}
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue
}
@ -919,16 +924,13 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
return nil, false, false
}
// routeACLsPass returns treu if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept
}
}
@ -972,9 +974,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
@ -984,23 +984,22 @@ func (m *Manager) AddUDPPacketHook(
udpHook: hook,
}
if ip.To4() != nil {
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip.String()][r.id] = r
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip.String()][r.id] = r
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
@ -1048,20 +1047,21 @@ func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.forwarder == nil {
fwder := m.forwarder.Load()
if fwder == nil {
return nil
}
m.routingEnabled = false
m.nativeRouter = false
m.routingEnabled.Store(false)
m.nativeRouter.Store(false)
// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nil
}
m.forwarder.Stop()
m.forwarder = nil
fwder.Stop()
m.forwarder.Store(nil)
log.Debug("forwarder stopped")

View File

@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}

View File

@ -306,8 +306,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled)
require.False(tb, manager.nativeRouter)
require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load())
tb.Cleanup(func() {
require.NoError(tb, manager.Reset(nil))
@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule))
})
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder
@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
})
for i, p := range tc.packets {
srcIP := net.ParseIP(p.srcIP)
dstIP := net.ParseIP(p.dstIP)
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)

View File

@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
ip := net.ParseIP("192.168.1.1")
ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string
in bool
expDir fw.RuleDirection
ip net.IP
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1),
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule
if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 {
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
for _, rule := range manager.incomingRules[tt.ip.String()] {
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
for _, rule := range manager.outgoingRules[tt.ip.String()] {
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
if !tt.ip.Equal(addedRule.ip) {
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
net.ParseIP("100.10.0.100"),
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
@ -573,7 +573,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
@ -641,7 +641,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")

View File

@ -2,6 +2,7 @@ package device
import (
"net"
"net/netip"
"sync"
"golang.zx2c4.com/wireguard/tun"
@ -19,7 +20,7 @@ type PacketFilter interface {
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error

View File

@ -6,6 +6,7 @@ package mocks
import (
net "net"
"net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)

View File

@ -2,7 +2,7 @@ package dns
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/google/gopacket"
@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
ip, err := netip.ParseAddr(s.runtimeIP)
if err != nil {
return "", fmt.Errorf("parse runtime ip: %w", err)
}
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
}

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@ -41,11 +42,21 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
srcIP = engine.GetWgAddr()
}
srcAddr, ok := netip.AddrFromSlice(srcIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
dstIP := net.ParseIP(req.GetDestinationIp())
if req.GetDestinationIp() == "self" {
dstIP = engine.GetWgAddr()
}
dstAddr, ok := netip.AddrFromSlice(dstIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("invalid IP address")
}
@ -85,8 +96,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
}
builder := &uspfilter.PacketBuilder{
SrcIP: srcIP,
DstIP: dstIP,
SrcIP: srcAddr,
DstIP: dstAddr,
Protocol: protocol,
SrcPort: uint16(req.GetSourcePort()),
DstPort: uint16(req.GetDestinationPort()),