mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-28 13:42:31 +02:00
Replace net.IP with netip.Addr (#3425)
This commit is contained in:
parent
419ed275fa
commit
e9f11fb11b
@ -4,6 +4,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -17,8 +18,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@ -35,8 +36,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
|
@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
@ -26,8 +27,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@ -44,8 +45,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
|
@ -2,7 +2,6 @@ package conntrack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -52,16 +51,3 @@ type ConnKey struct {
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
}
|
||||
|
||||
// makeConnKey creates a connection key
|
||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
|
||||
return ConnKey{
|
||||
SrcIP: srcAddr,
|
||||
DstIP: dstAddr,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -21,11 +21,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
@ -46,11 +46,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
@ -2,7 +2,6 @@ package conntrack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
@ -70,8 +69,13 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16) (ICMPConnKey, bool) {
|
||||
key := ICMPConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -87,7 +91,7 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP connection
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
|
||||
@ -95,12 +99,12 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
|
||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
|
||||
// TODO: icmp doesn't need to extend the timeout
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
|
||||
if exists {
|
||||
@ -112,7 +116,7 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendStartEvent(direction, key, typ, code)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code)
|
||||
return
|
||||
}
|
||||
|
||||
@ -120,8 +124,8 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
@ -133,16 +137,21 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||
t.sendEvent(nftypes.TypeStart, conn)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, icmpType uint8) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
||||
key := ICMPConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -177,7 +186,7 @@ func (t *ICMPTracker) cleanup() {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.sendEvent(nftypes.TypeEnd, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -192,40 +201,28 @@ func (t *ICMPTracker) Close() {
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
|
||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
ICMPCode: conn.ICMPCode,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) {
|
||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeStart,
|
||||
Direction: direction,
|
||||
Protocol: nftypes.ICMP,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
})
|
||||
}
|
||||
|
||||
// makeICMPKey creates an ICMP connection key
|
||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
return ICMPConnKey{
|
||||
SrcIP: srcAddr,
|
||||
DstIP: dstAddr,
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -10,8 +10,8 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -23,8 +23,8 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
|
@ -3,7 +3,7 @@ package conntrack
|
||||
// TODO: Send RST packets for invalid/timed-out connections
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -144,8 +144,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -154,7 +159,6 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
if exists {
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||
conn.UpdateLastSeen()
|
||||
conn.Unlock()
|
||||
|
||||
return key, true
|
||||
@ -164,7 +168,7 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
|
||||
@ -172,12 +176,12 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
|
||||
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
|
||||
if exists {
|
||||
return
|
||||
@ -187,14 +191,13 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
conn.tombstone.Store(false)
|
||||
|
||||
@ -205,12 +208,17 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||
t.sendEvent(nftypes.TypeStart, conn)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -233,13 +241,12 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
conn.Unlock()
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s", key)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.sendEvent(nftypes.TypeEnd, conn)
|
||||
return true
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, false)
|
||||
conn.UpdateLastSeen()
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
@ -249,6 +256,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
conn.UpdateLastSeen()
|
||||
|
||||
state := conn.State
|
||||
defer func() {
|
||||
if state != conn.State {
|
||||
@ -312,7 +321,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
||||
conn.State = TCPStateTimeWait
|
||||
|
||||
t.logger.Trace("TCP connection %s completed", key)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.sendEvent(nftypes.TypeEnd, conn)
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
@ -321,7 +330,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.sendEvent(nftypes.TypeEnd, conn)
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
@ -335,7 +344,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
||||
conn.SetTombstone()
|
||||
|
||||
// Send close event for gracefully closed connections
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.sendEvent(nftypes.TypeEnd, conn)
|
||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||
}
|
||||
}
|
||||
@ -422,7 +431,7 @@ func (t *TCPTracker) cleanup() {
|
||||
|
||||
// event already handled by state change
|
||||
if conn.State != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.sendEvent(nftypes.TypeEnd, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -453,15 +462,15 @@ func isValidFlagCombination(flags uint8) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
|
||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.TCP,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourcePort: key.SrcPort,
|
||||
DestPort: key.DstPort,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
})
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
|
||||
tt.sendRST()
|
||||
|
||||
// Verify connection state is as expected
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
@ -220,7 +225,7 @@ func TestRSTHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
@ -236,8 +241,8 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -249,8 +254,8 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
@ -267,8 +272,8 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
@ -291,8 +296,8 @@ func BenchmarkCleanup(b *testing.B) {
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -54,7 +54,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
|
||||
@ -62,12 +62,17 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -82,7 +87,7 @@ func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
|
||||
if exists {
|
||||
return
|
||||
@ -92,8 +97,8 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
@ -105,12 +110,17 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||
t.sendEvent(nftypes.TypeStart, conn)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
@ -146,7 +156,7 @@ func (t *UDPTracker) cleanup() {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %s (timeout)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||
t.sendEvent(nftypes.TypeEnd, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -162,11 +172,16 @@ func (t *UDPTracker) Close() {
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
@ -176,15 +191,15 @@ func (t *UDPTracker) Timeout() time.Duration {
|
||||
return t.timeout
|
||||
}
|
||||
|
||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
|
||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.UDP,
|
||||
SourceIP: key.SrcIP,
|
||||
DestIP: key.DstIP,
|
||||
SourcePort: key.SrcPort,
|
||||
DestPort: key.DstPort,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
})
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@ -49,10 +48,15 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Verify connection was tracked
|
||||
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := tracker.connections[key]
|
||||
require.True(t, exists)
|
||||
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||
@ -66,8 +70,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
@ -76,8 +80,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
sleep time.Duration
|
||||
@ -94,7 +98,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid source IP",
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
@ -104,7 +108,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
{
|
||||
name: "invalid destination IP",
|
||||
srcIP: dstIP,
|
||||
dstIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
@ -170,20 +174,20 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
// Add some connections
|
||||
connections := []struct {
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.2"),
|
||||
dstIP: net.ParseIP("192.168.1.3"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
},
|
||||
{
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.5"),
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||
srcPort: 12346,
|
||||
dstPort: 53,
|
||||
},
|
||||
@ -215,8 +219,8 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -228,8 +232,8 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
|
@ -3,6 +3,7 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
if ip.Is4() {
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
return false
|
||||
|
@ -2,6 +2,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
testIP netip.Addr
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
@ -73,7 +74,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
@ -85,7 +86,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@ -13,7 +12,7 @@ import (
|
||||
type PeerRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
|
@ -2,7 +2,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@ -53,8 +53,8 @@ type TraceResult struct {
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
SourceIP netip.Addr
|
||||
DestinationIP netip.Addr
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
@ -72,8 +72,8 @@ type TCPState struct {
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -260,7 +260,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
@ -273,14 +273,14 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter {
|
||||
if m.nativeRouter.Load() {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
@ -309,13 +309,12 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
|
||||
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
@ -341,7 +340,7 @@ func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled {
|
||||
if !m.routingEnabled.Load() {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
@ -357,7 +356,7 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
@ -373,7 +372,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder != nil {
|
||||
if allowed && m.forwarder.Load() != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
|
0
client/firewall/uspfilter/tracer_test.go
Normal file
0
client/firewall/uspfilter/tracer_test.go
Normal file
@ -10,6 +10,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@ -66,9 +67,9 @@ func (r RouteRules) Sort() {
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
// outgoingRules is used for hooks only
|
||||
outgoingRules map[string]RuleSet
|
||||
outgoingRules map[netip.Addr]RuleSet
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[string]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
@ -80,9 +81,9 @@ type Manager struct {
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
// indicates whether we forward packets not destined for ourselves
|
||||
routingEnabled bool
|
||||
routingEnabled atomic.Bool
|
||||
// indicates whether we leave forwarding and filtering to the native firewall
|
||||
nativeRouter bool
|
||||
nativeRouter atomic.Bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
@ -95,7 +96,7 @@ type Manager struct {
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder *forwarder.Forwarder
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
@ -168,18 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||
incomingRules: make(map[netip.Addr]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
routingEnabled: false,
|
||||
stateful: !disableConntrack,
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
@ -211,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder == nil {
|
||||
if m.forwarder.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
@ -255,20 +256,20 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
switch {
|
||||
case disableUspRouting:
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
log.Info("userspace routing is disabled")
|
||||
|
||||
case m.disableServerRoutes:
|
||||
// if server routes are disabled we will let packets pass to the native stack
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
|
||||
log.Info("server routes are disabled")
|
||||
|
||||
case forceUserspaceRouter:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
log.Info("userspace routing is forced")
|
||||
|
||||
@ -276,19 +277,19 @@ func (m *Manager) determineRouting() error {
|
||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||
// netstack mode won't support native routing as there is no interface
|
||||
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = true
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
|
||||
log.Info("native routing is enabled")
|
||||
|
||||
default:
|
||||
m.routingEnabled = true
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
log.Info("userspace routing enabled by default")
|
||||
}
|
||||
|
||||
if m.routingEnabled && !m.nativeRouter {
|
||||
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||
return m.initForwarder()
|
||||
}
|
||||
|
||||
@ -297,24 +298,24 @@ func (m *Manager) determineRouting() error {
|
||||
|
||||
// initForwarder initializes the forwarder, it disables routing on errors
|
||||
func (m *Manager) initForwarder() error {
|
||||
if m.forwarder != nil {
|
||||
if m.forwarder.Load() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
intf := m.wgIface.GetWGDevice()
|
||||
if intf == nil {
|
||||
m.routingEnabled = false
|
||||
m.routingEnabled.Store(false)
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||
if err != nil {
|
||||
m.routingEnabled = false
|
||||
m.routingEnabled.Store(false)
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
}
|
||||
|
||||
m.forwarder = forwarder
|
||||
m.forwarder.Store(forwarder)
|
||||
|
||||
log.Debug("forwarder initialized")
|
||||
|
||||
@ -330,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
}
|
||||
|
||||
@ -341,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.RemoveNatRule(pair)
|
||||
}
|
||||
return nil
|
||||
@ -360,17 +361,23 @@ func (m *Manager) AddPeerFiltering(
|
||||
action firewall.Action,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
// TODO: fix in upper layers
|
||||
i, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
|
||||
i = i.Unmap()
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
mgmtId: id,
|
||||
ip: ip,
|
||||
ip: i,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
}
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
if i.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
r.ip = ipNormalized
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
@ -395,10 +402,10 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
@ -412,13 +419,10 @@ func (m *Manager) AddRouteFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
@ -432,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
|
||||
action: action,
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules.Sort()
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter && m.nativeFirewall != nil {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
@ -468,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
||||
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
delete(m.incomingRules[r.ip], r.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -519,9 +525,6 @@ func (m *Manager) UpdateLocalIPs() error {
|
||||
}
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@ -534,7 +537,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return false
|
||||
}
|
||||
|
||||
@ -551,14 +555,18 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||
switch d.decoded[0] {
|
||||
case layers.LayerTypeIPv4:
|
||||
return d.ip4.SrcIP, d.ip4.DstIP
|
||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
return src, dst
|
||||
case layers.LayerTypeIPv6:
|
||||
return d.ip6.SrcIP, d.ip6.DstIP
|
||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||
return src, dst
|
||||
default:
|
||||
return nil, nil
|
||||
return netip.Addr{}, netip.Addr{}
|
||||
}
|
||||
}
|
||||
|
||||
@ -585,7 +593,7 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
@ -598,7 +606,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
@ -611,8 +619,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for _, ipKey := range []netip.Addr{dstIP, netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
@ -627,9 +638,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
|
||||
// dropFilter implements filtering logic for incoming packets.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
@ -638,7 +646,7 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
}
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if srcIP == nil {
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
@ -658,15 +666,13 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
_, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort)
|
||||
ruleId, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@ -674,8 +680,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
|
||||
RuleID: ruleId,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcAddr,
|
||||
DestIP: dstAddr,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
@ -700,12 +706,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.forwarder == nil {
|
||||
if m.forwarder.Load() == nil {
|
||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||
return true
|
||||
}
|
||||
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
|
||||
@ -715,16 +721,16 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled {
|
||||
if !m.routingEnabled.Load() {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter {
|
||||
if m.nativeRouter.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -732,9 +738,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
id, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
@ -744,8 +747,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
||||
RuleID: id,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: pnum,
|
||||
SourceIP: srcAddr,
|
||||
DestIP: dstAddr,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
@ -754,7 +757,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
@ -799,7 +802,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return m.tcpTracker.IsValidInbound(
|
||||
@ -844,20 +847,22 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) {
|
||||
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
if m.isSpecialICMP(d) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
@ -882,10 +887,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -919,16 +924,13 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
||||
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||
}
|
||||
}
|
||||
@ -972,9 +974,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
@ -984,23 +984,22 @@ func (m *Manager) AddUDPPacketHook(
|
||||
udpHook: hook,
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
@ -1048,20 +1047,21 @@ func (m *Manager) DisableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.forwarder == nil {
|
||||
fwder := m.forwarder.Load()
|
||||
if fwder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.routingEnabled = false
|
||||
m.nativeRouter = false
|
||||
m.routingEnabled.Store(false)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
// don't stop forwarder if in use by netstack
|
||||
if m.netstack && m.localForwarding {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.forwarder.Stop()
|
||||
m.forwarder = nil
|
||||
fwder.Stop()
|
||||
m.forwarder.Store(nil)
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
|
@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
|
@ -306,8 +306,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NoError(tb, err)
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled)
|
||||
require.False(tb, manager.nativeRouter)
|
||||
require.True(tb, manager.routingEnabled.Load())
|
||||
require.False(tb, manager.nativeRouter.Load())
|
||||
|
||||
tb.Cleanup(func() {
|
||||
require.NoError(tb, manager.Reset(nil))
|
||||
@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
})
|
||||
|
||||
for i, p := range tc.packets {
|
||||
srcIP := net.ParseIP(p.srcIP)
|
||||
dstIP := net.ParseIP(p.dstIP)
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
|
@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
ip := netip.MustParseAddr("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: net.IPv4(10, 168, 0, 1),
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: net.IPv6loopback,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
net.ParseIP("100.10.0.100"),
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
hookCalled = true
|
||||
@ -573,7 +573,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||
@ -641,7 +641,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
if cp.shouldAllow {
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
require.True(t, exists, "Connection should still exist during valid window")
|
||||
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||
"LastSeen should be updated for valid responses")
|
||||
|
@ -2,6 +2,7 @@ package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@ -19,7 +20,7 @@ type PacketFilter interface {
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
|
@ -6,6 +6,7 @@ package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
|
@ -2,7 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
@ -41,11 +42,21 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
srcIP = engine.GetWgAddr()
|
||||
}
|
||||
|
||||
srcAddr, ok := netip.AddrFromSlice(srcIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
}
|
||||
|
||||
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||
if req.GetDestinationIp() == "self" {
|
||||
dstIP = engine.GetWgAddr()
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(dstIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
}
|
||||
|
||||
if srcIP == nil || dstIP == nil {
|
||||
return nil, fmt.Errorf("invalid IP address")
|
||||
}
|
||||
@ -85,8 +96,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
}
|
||||
|
||||
builder := &uspfilter.PacketBuilder{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcIP: srcAddr,
|
||||
DstIP: dstAddr,
|
||||
Protocol: protocol,
|
||||
SrcPort: uint16(req.GetSourcePort()),
|
||||
DstPort: uint16(req.GetDestinationPort()),
|
||||
|
Loading…
x
Reference in New Issue
Block a user