Replace net.IP with netip.Addr (#3425)

This commit is contained in:
Viktor Liu 2025-03-05 18:28:05 +01:00 committed by GitHub
parent 419ed275fa
commit e9f11fb11b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 341 additions and 309 deletions

View File

@ -4,6 +4,7 @@ package uspfilter
import ( import (
"context" "context"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -17,8 +18,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
@ -35,8 +36,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@ -3,6 +3,7 @@ package uspfilter
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time" "time"
@ -26,8 +27,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
@ -44,8 +45,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@ -2,7 +2,6 @@ package conntrack
import ( import (
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
"time" "time"
@ -52,16 +51,3 @@ type ConnKey struct {
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
} }
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@ -2,7 +2,7 @@ package conntrack
import ( import (
"context" "context"
"net" "net/netip"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -21,11 +21,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
@ -46,11 +46,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()

View File

@ -2,7 +2,6 @@ package conntrack
import ( import (
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@ -70,8 +69,13 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
return tracker return tracker
} }
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) { func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16) (ICMPConnKey, bool) {
key := makeICMPKey(srcIP, dstIP, id, seq) key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
Sequence: seq,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -87,7 +91,7 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress) t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
@ -95,12 +99,12 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress) t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
// TODO: icmp doesn't need to extend the timeout // TODO: icmp doesn't need to extend the timeout
key, exists := t.updateIfExists(srcIP, dstIP, id, seq) key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
if exists { if exists {
@ -112,7 +116,7 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, key, typ, code) t.sendStartEvent(direction, srcIP, dstIP, typ, code)
return return
} }
@ -120,8 +124,8 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
BaseConnTrack: BaseConnTrack{ BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(), FlowId: uuid.New(),
Direction: direction, Direction: direction,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
}, },
ICMPType: typ, ICMPType: typ,
ICMPCode: code, ICMPCode: code,
@ -133,16 +137,21 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, key, conn) t.sendEvent(nftypes.TypeStart, conn)
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) { if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false return false
} }
key := makeICMPKey(dstIP, srcIP, id, seq) key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
Sequence: seq,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -177,7 +186,7 @@ func (t *ICMPTracker) cleanup() {
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key) t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn)
} }
} }
} }
@ -192,40 +201,28 @@ func (t *ICMPTracker) Close() {
t.mutex.Unlock() t.mutex.Unlock()
} }
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) { func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: key.SrcIP, SourceIP: conn.SourceIP,
DestIP: key.DstIP, DestIP: conn.DestIP,
ICMPType: conn.ICMPType, ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode, ICMPCode: conn.ICMPCode,
}) })
} }
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) { func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeStart, Type: nftypes.TypeStart,
Direction: direction, Direction: direction,
Protocol: nftypes.ICMP, Protocol: nftypes.ICMP,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
ICMPType: typ, ICMPType: typ,
ICMPCode: code, ICMPCode: code,
}) })
} }
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ICMPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
}
}

View File

@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
) )
@ -10,8 +10,8 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -23,8 +23,8 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {

View File

@ -3,7 +3,7 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections // TODO: Send RST packets for invalid/timed-out connections
import ( import (
"net" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -144,8 +144,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -154,7 +159,6 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
if exists { if exists {
conn.Lock() conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress) t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.UpdateLastSeen()
conn.Unlock() conn.Unlock()
return key, true return key, true
@ -164,7 +168,7 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
@ -172,12 +176,12 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) { func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
if exists { if exists {
return return
@ -187,14 +191,13 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
BaseConnTrack: BaseConnTrack{ BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(), FlowId: uuid.New(),
Direction: direction, Direction: direction,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
}, },
} }
conn.UpdateLastSeen()
conn.established.Store(false) conn.established.Store(false)
conn.tombstone.Store(false) conn.tombstone.Store(false)
@ -205,12 +208,17 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, key, conn) t.sendEvent(nftypes.TypeStart, conn)
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -233,13 +241,12 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
conn.Unlock() conn.Unlock()
t.logger.Trace("TCP connection reset: %s", key) t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn)
return true return true
} }
conn.Lock() conn.Lock()
t.updateState(key, conn, flags, false) t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished() isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags) isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock() conn.Unlock()
@ -249,6 +256,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) { func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen()
state := conn.State state := conn.State
defer func() { defer func() {
if state != conn.State { if state != conn.State {
@ -312,7 +321,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key) t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn)
} }
case TCPStateClosing: case TCPStateClosing:
@ -321,7 +330,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key) t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@ -335,7 +344,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.SetTombstone() conn.SetTombstone()
// Send close event for gracefully closed connections // Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn)
t.logger.Trace("TCP connection %s closed gracefully", key) t.logger.Trace("TCP connection %s closed gracefully", key)
} }
} }
@ -422,7 +431,7 @@ func (t *TCPTracker) cleanup() {
// event already handled by state change // event already handled by state change
if conn.State != TCPStateTimeWait { if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn)
} }
} }
} }
@ -453,15 +462,15 @@ func isValidFlagCombination(flags uint8) bool {
return true return true
} }
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) { func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.TCP, Protocol: nftypes.TCP,
SourceIP: key.SrcIP, SourceIP: conn.SourceIP,
DestIP: key.DstIP, DestIP: conn.DestIP,
SourcePort: key.SrcPort, SourcePort: conn.SourcePort,
DestPort: key.DstPort, DestPort: conn.DestPort,
}) })
} }

View File

@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST() tt.sendRST()
// Verify connection state is as expected // Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
@ -220,7 +225,7 @@ func TestRSTHandling(t *testing.T) {
} }
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
@ -236,8 +241,8 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -249,8 +254,8 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
@ -267,8 +272,8 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
i := 0 i := 0
@ -291,8 +296,8 @@ func BenchmarkCleanup(b *testing.B) {
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
} }

View File

@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"sync" "sync"
"time" "time"
@ -54,7 +54,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
@ -62,12 +62,17 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
} }
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -82,7 +87,7 @@ func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
if exists { if exists {
return return
@ -92,8 +97,8 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
BaseConnTrack: BaseConnTrack{ BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(), FlowId: uuid.New(),
Direction: direction, Direction: direction,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
}, },
@ -105,12 +110,17 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key) t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, key, conn) t.sendEvent(nftypes.TypeStart, conn)
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@ -146,7 +156,7 @@ func (t *UDPTracker) cleanup() {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout)", key) t.logger.Trace("Removed UDP connection %s (timeout)", key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn)
} }
} }
} }
@ -162,11 +172,16 @@ func (t *UDPTracker) Close() {
} }
// GetConnection safely retrieves a connection state // GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock() t.mutex.RLock()
defer t.mutex.RUnlock() defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key] conn, exists := t.connections[key]
return conn, exists return conn, exists
} }
@ -176,15 +191,15 @@ func (t *UDPTracker) Timeout() time.Duration {
return t.timeout return t.timeout
} }
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) { func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.UDP, Protocol: nftypes.UDP,
SourceIP: key.SrcIP, SourceIP: conn.SourceIP,
DestIP: key.DstIP, DestIP: conn.DestIP,
SourcePort: key.SrcPort, SourcePort: conn.SourcePort,
DestPort: key.DstPort, DestPort: conn.DestPort,
}) })
} }

View File

@ -1,7 +1,6 @@
package conntrack package conntrack
import ( import (
"net"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@ -49,10 +48,15 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
// Verify connection was tracked // Verify connection was tracked
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key] conn, exists := tracker.connections[key]
require.True(t, exists) require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0) assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
@ -66,8 +70,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger) tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
@ -76,8 +80,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
sleep time.Duration sleep time.Duration
@ -94,7 +98,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
}, },
{ {
name: "invalid source IP", name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP, dstIP: srcIP,
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
@ -104,7 +108,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{ {
name: "invalid destination IP", name: "invalid destination IP",
srcIP: dstIP, srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"), dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
sleep: 0, sleep: 0,
@ -170,20 +174,20 @@ func TestUDPTracker_Cleanup(t *testing.T) {
// Add some connections // Add some connections
connections := []struct { connections := []struct {
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
}{ }{
{ {
srcIP: net.ParseIP("192.168.1.2"), srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"), dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345, srcPort: 12345,
dstPort: 53, dstPort: 53,
}, },
{ {
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"), dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346, srcPort: 12346,
dstPort: 53, dstPort: 53,
}, },
@ -215,8 +219,8 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -228,8 +232,8 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {

View File

@ -3,6 +3,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high] |= 1 << (low % 32) m.ipv4Bitmap[high] |= 1 << (low % 32)
} }
func (m *localIPManager) checkBitmapBit(ip net.IP) bool { func (m *localIPManager) checkBitmapBit(ip []byte) bool {
ipv4 := ip.To4() high := (uint16(ip[0]) << 8) | uint16(ip[1])
if ipv4 == nil { low := (uint16(ip[2]) << 8) | uint16(ip[3])
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
} }
@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil return nil
} }
func (m *localIPManager) IsLocalIP(ip net.IP) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil { if ip.Is4() {
return m.checkBitmapBit(ipv4) return m.checkBitmapBit(ip.AsSlice())
} }
return false return false

View File

@ -2,6 +2,7 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupAddr iface.WGAddress setupAddr iface.WGAddress
testIP net.IP testIP netip.Addr
expected bool expected bool
}{ }{
{ {
@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
}, },
{ {
@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
}, },
{ {
@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
}, },
{ {
@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
}, },
{ {
@ -73,7 +74,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
}, },
{ {
@ -85,7 +86,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(64, 128), Mask: net.CIDRMask(64, 128),
}, },
}, },
testIP: net.ParseIP("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,
}, },
} }
@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests)) t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) { t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip)) result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip) require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
}) })
} }

View File

@ -1,7 +1,6 @@
package uspfilter package uspfilter
import ( import (
"net"
"net/netip" "net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
@ -13,7 +12,7 @@ import (
type PeerRule struct { type PeerRule struct {
id string id string
mgmtId []byte mgmtId []byte
ip net.IP ip netip.Addr
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType

View File

@ -2,7 +2,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net/netip"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
@ -53,8 +53,8 @@ type TraceResult struct {
} }
type PacketTrace struct { type PacketTrace struct {
SourceIP net.IP SourceIP netip.Addr
DestinationIP net.IP DestinationIP netip.Addr
Protocol string Protocol string
SourcePort uint16 SourcePort uint16
DestinationPort uint16 DestinationPort uint16
@ -72,8 +72,8 @@ type TCPState struct {
} }
type PacketBuilder struct { type PacketBuilder struct {
SrcIP net.IP SrcIP netip.Addr
DstIP net.IP DstIP netip.Addr
Protocol fw.Protocol Protocol fw.Protocol
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP, SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP, DstIP: p.DstIP.AsSlice(),
} }
} }
@ -260,7 +260,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP) return m.traceInbound(packetData, trace, d, srcIP, dstIP)
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
@ -273,14 +273,14 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder
return trace return trace
} }
if m.nativeRouter { if m.nativeRouter.Load() {
return m.handleNativeRouter(trace) return m.handleNativeRouter(trace)
} }
return m.handleRouteACLs(trace, d, srcIP, dstIP) return m.handleRouteACLs(trace, d, srcIP, dstIP)
} }
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP) allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
msg := "No existing connection found" msg := "No existing connection found"
if allowed { if allowed {
@ -309,13 +309,12 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg return msg
} }
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.localForwarding { if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false) trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true return true
} }
trace.AddResult(StageRouting, "Packet destined for local delivery", true) trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
@ -341,7 +340,7 @@ func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *
} }
func (m *Manager) handleRouting(trace *PacketTrace) bool { func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled { if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false) trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false return false
@ -357,7 +356,7 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace return trace
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d) proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
@ -373,7 +372,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
} }
trace.AddResult(StageRouteACL, msg, allowed) trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil { if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
} }

View File

View File

@ -10,6 +10,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@ -66,9 +67,9 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
// outgoingRules is used for hooks only // outgoingRules is used for hooks only
outgoingRules map[string]RuleSet outgoingRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@ -80,9 +81,9 @@ type Manager struct {
// indicates whether server routes are disabled // indicates whether server routes are disabled
disableServerRoutes bool disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves // indicates whether we forward packets not destined for ourselves
routingEnabled bool routingEnabled atomic.Bool
// indicates whether we leave forwarding and filtering to the native firewall // indicates whether we leave forwarding and filtering to the native firewall
nativeRouter bool nativeRouter atomic.Bool
// indicates whether we track outbound connections // indicates whether we track outbound connections
stateful bool stateful bool
// indicates whether wireguards runs in netstack mode // indicates whether wireguards runs in netstack mode
@ -95,7 +96,7 @@ type Manager struct {
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
@ -168,18 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}, },
}, },
nativeFirewall: nativeFirewall, nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface, wgIface: iface,
localipmanager: newLocalIPManager(), localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes, disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack, stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()), logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger, flowLogger: flowLogger,
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
} }
m.routingEnabled.Store(false)
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
@ -211,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil { if m.forwarder.Load() == nil {
return nil return nil
} }
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
@ -255,20 +256,20 @@ func (m *Manager) determineRouting() error {
switch { switch {
case disableUspRouting: case disableUspRouting:
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is disabled") log.Info("userspace routing is disabled")
case m.disableServerRoutes: case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack // if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("server routes are disabled") log.Info("server routes are disabled")
case forceUserspaceRouter: case forceUserspaceRouter:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
@ -276,19 +277,19 @@ func (m *Manager) determineRouting() error {
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("native routing is enabled") log.Info("native routing is enabled")
default: default:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing enabled by default") log.Info("userspace routing enabled by default")
} }
if m.routingEnabled && !m.nativeRouter { if m.routingEnabled.Load() && !m.nativeRouter.Load() {
return m.initForwarder() return m.initForwarder()
} }
@ -297,24 +298,24 @@ func (m *Manager) determineRouting() error {
// initForwarder initializes the forwarder, it disables routing on errors // initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error { func (m *Manager) initForwarder() error {
if m.forwarder != nil { if m.forwarder.Load() != nil {
return nil return nil
} }
// Only supported in userspace mode as we need to inject packets back into wireguard directly // Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice() intf := m.wgIface.GetWGDevice()
if intf == nil { if intf == nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return errors.New("forwarding not supported") return errors.New("forwarding not supported")
} }
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil { if err != nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err) return fmt.Errorf("create forwarder: %w", err)
} }
m.forwarder = forwarder m.forwarder.Store(forwarder)
log.Debug("forwarder initialized") log.Debug("forwarder initialized")
@ -330,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
} }
@ -341,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule // RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair) return m.nativeFirewall.RemoveNatRule(pair)
} }
return nil return nil
@ -360,17 +361,23 @@ func (m *Manager) AddPeerFiltering(
action firewall.Action, action firewall.Action,
_ string, _ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
mgmtId: id, mgmtId: id,
ip: ip, ip: i,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
drop: action == firewall.ActionDrop, drop: action == firewall.ActionDrop,
} }
if ipNormalized := ip.To4(); ipNormalized != nil { if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
} }
if s := r.ip.String(); s == "0.0.0.0" || s == "::" { if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
@ -395,10 +402,10 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet) m.incomingRules[r.ip] = make(RuleSet)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
@ -412,13 +419,10 @@ func (m *Manager) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
@ -432,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
action: action, action: action,
} }
m.mutex.Lock()
m.routeRules = append(m.routeRules, rule) m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort() m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
@ -468,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok { if _, ok := m.incomingRules[r.ip][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
} }
delete(m.incomingRules[r.ip.String()], r.id) delete(m.incomingRules[r.ip], r.id)
return nil return nil
} }
@ -519,9 +525,6 @@ func (m *Manager) UpdateLocalIPs() error {
} }
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@ -534,7 +537,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return false return false
} }
@ -551,14 +555,18 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
return false return false
} }
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
switch d.decoded[0] { switch d.decoded[0] {
case layers.LayerTypeIPv4: case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
case layers.LayerTypeIPv6: case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
default: default:
return nil, nil return netip.Addr{}, netip.Addr{}
} }
} }
@ -585,7 +593,7 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
@ -598,7 +606,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
} }
} }
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) { func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
@ -611,8 +619,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
} }
} }
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { m.mutex.RLock()
defer m.mutex.RUnlock()
for _, ipKey := range []netip.Addr{dstIP, netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
if rules, exists := m.outgoingRules[ipKey]; exists { if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules { for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
@ -627,9 +638,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
// dropFilter implements filtering logic for incoming packets. // dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool { func (m *Manager) dropFilter(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@ -638,7 +646,7 @@ func (m *Manager) dropFilter(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0]) m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true return true
} }
@ -658,15 +666,13 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked { if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
_, pnum := getProtocolFromPacket(d) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort) ruleId, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
@ -674,8 +680,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
RuleID: ruleId, RuleID: ruleId,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: pnum,
SourceIP: srcAddr, SourceIP: srcIP,
DestIP: dstAddr, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
// TODO: icmp type/code // TODO: icmp type/code
@ -700,12 +706,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
return false return false
} }
if m.forwarder == nil { if m.forwarder.Load() == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true return true
} }
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err) m.logger.Error("Failed to inject local packet: %v", err)
} }
@ -715,16 +721,16 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic. // handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled { if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP) srcIP, dstIP)
return true return true
} }
// Pass to native stack if native router is enabled or forced // Pass to native stack if native router is enabled or forced
if m.nativeRouter { if m.nativeRouter.Load() {
return false return false
} }
@ -732,9 +738,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
id, pnum, srcIP, srcPort, dstIP, dstPort) id, pnum, srcIP, srcPort, dstIP, dstPort)
@ -744,8 +747,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
RuleID: id, RuleID: id,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: pnum,
SourceIP: srcAddr, SourceIP: srcIP,
DestIP: dstAddr, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
// TODO: icmp type/code // TODO: icmp type/code
@ -754,7 +757,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
} }
// Let forwarder handle the packet if it passed route ACLs // Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err) m.logger.Error("Failed to inject incoming packet: %v", err)
} }
@ -799,7 +802,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true return true
} }
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound( return m.tcpTracker.IsValidInbound(
@ -844,20 +847,22 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) { func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
if m.isSpecialICMP(d) { if m.isSpecialICMP(d) {
return nil, false return nil, false
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter return mgmtId, filter
} }
@ -882,10 +887,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false return false
} }
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) { func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) { if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue continue
} }
@ -919,16 +924,13 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
return nil, false, false return nil, false, false
} }
// routeACLsPass returns treu if the packet is allowed by the route ACLs // routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept return rule.mgmtId, rule.action == firewall.ActionAccept
} }
} }
@ -972,9 +974,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook( func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
@ -984,23 +984,22 @@ func (m *Manager) AddUDPPacketHook(
udpHook: hook, udpHook: hook,
} }
if ip.To4() != nil { if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
} }
m.mutex.Lock() m.mutex.Lock()
if in { if in {
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule) m.incomingRules[r.ip] = make(map[string]PeerRule)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
} else { } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok { if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) m.outgoingRules[r.ip] = make(map[string]PeerRule)
} }
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip][r.id] = r
} }
m.mutex.Unlock() m.mutex.Unlock()
return r.id return r.id
@ -1048,20 +1047,21 @@ func (m *Manager) DisableRouting() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.forwarder == nil { fwder := m.forwarder.Load()
if fwder == nil {
return nil return nil
} }
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
// don't stop forwarder if in use by netstack // don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
return nil return nil
} }
m.forwarder.Stop() fwder.Stop()
m.forwarder = nil m.forwarder.Store(nil)
log.Debug("forwarder stopped") log.Debug("forwarder stopped")

View File

@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tc := range cases { for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
} }
} }

View File

@ -306,8 +306,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
require.NoError(tb, manager.EnableRouting()) require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err) require.NoError(tb, err)
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter) require.False(tb, manager.nativeRouter.Load())
tb.Cleanup(func() { tb.Cleanup(func() {
require.NoError(tb, manager.Reset(nil)) require.NoError(tb, manager.Reset(nil))
@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule)) require.NoError(t, manager.DeleteRouteRule(rule))
}) })
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
}) })
for i, p := range tc.packets { for i, p := range tc.packets {
srcIP := net.ParseIP(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := net.ParseIP(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)

View File

@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
ip := net.ParseIP("192.168.1.1") ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok { if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok { if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string name string
in bool in bool
expDir fw.RuleDirection expDir fw.RuleDirection
ip net.IP ip netip.Addr
dPort uint16 dPort uint16
hook func([]byte) bool hook func([]byte) bool
expectedID string expectedID string
@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook", name: "Test Outgoing UDP Packet Hook",
in: false, in: false,
expDir: fw.RuleDirectionOUT, expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1), ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000, dPort: 8000,
hook: func([]byte) bool { return true }, hook: func([]byte) bool { return true },
}, },
@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook", name: "Test Incoming UDP Packet Hook",
in: true, in: true,
expDir: fw.RuleDirectionIN, expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback, ip: netip.MustParseAddr("::1"),
dPort: 9000, dPort: 9000,
hook: func([]byte) bool { return false }, hook: func([]byte) bool { return false },
}, },
@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
for _, rule := range manager.incomingRules[tt.ip.String()] { for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} else { } else {
@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip.String()] { for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} }
if !tt.ip.Equal(addedRule.ip) { if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return return
} }
@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
// Add a UDP packet hook // Add a UDP packet hook
hookFunc := func(data []byte) bool { return true } hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false hookCalled := false
hookID := manager.AddUDPPacketHook( hookID := manager.AddUDPPacketHook(
false, false,
net.ParseIP("100.10.0.100"), netip.MustParseAddr("100.10.0.100"),
53, 53,
func([]byte) bool { func([]byte) bool {
hookCalled = true hookCalled = true
@ -573,7 +573,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet") require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match") require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
@ -641,7 +641,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
if cp.shouldAllow { if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should still exist during valid window") require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses") "LastSeen should be updated for valid responses")

View File

@ -2,6 +2,7 @@ package device
import ( import (
"net" "net"
"net/netip"
"sync" "sync"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@ -19,7 +20,7 @@ type PacketFilter interface {
// //
// Hook function returns flag which indicates should be the matched package dropped or not. // Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument. // Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error

View File

@ -6,6 +6,7 @@ package mocks
import ( import (
net "net" net "net"
"net/netip"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
} }
// AddUDPPacketHook mocks base method. // AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)

View File

@ -2,7 +2,7 @@ package dns
import ( import (
"fmt" "fmt"
"net" "net/netip"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true return true
} }
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil ip, err := netip.ParseAddr(s.runtimeIP)
if err != nil {
return "", fmt.Errorf("parse runtime ip: %w", err)
}
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
@ -41,11 +42,21 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
srcIP = engine.GetWgAddr() srcIP = engine.GetWgAddr()
} }
srcAddr, ok := netip.AddrFromSlice(srcIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
dstIP := net.ParseIP(req.GetDestinationIp()) dstIP := net.ParseIP(req.GetDestinationIp())
if req.GetDestinationIp() == "self" { if req.GetDestinationIp() == "self" {
dstIP = engine.GetWgAddr() dstIP = engine.GetWgAddr()
} }
dstAddr, ok := netip.AddrFromSlice(dstIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
if srcIP == nil || dstIP == nil { if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("invalid IP address") return nil, fmt.Errorf("invalid IP address")
} }
@ -85,8 +96,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
} }
builder := &uspfilter.PacketBuilder{ builder := &uspfilter.PacketBuilder{
SrcIP: srcIP, SrcIP: srcAddr,
DstIP: dstIP, DstIP: dstAddr,
Protocol: protocol, Protocol: protocol,
SrcPort: uint16(req.GetSourcePort()), SrcPort: uint16(req.GetSourcePort()),
DstPort: uint16(req.GetDestinationPort()), DstPort: uint16(req.GetDestinationPort()),