mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-28 13:42:31 +02:00
Replace net.IP with netip.Addr (#3425)
This commit is contained in:
parent
419ed275fa
commit
e9f11fb11b
@ -4,6 +4,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -17,8 +18,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@ -35,8 +36,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@ -26,8 +27,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@ -44,8 +45,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
@ -2,7 +2,6 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -52,16 +51,3 @@ type ConnKey struct {
|
|||||||
func (c ConnKey) String() string {
|
func (c ConnKey) String() string {
|
||||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeConnKey creates a connection key
|
|
||||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
|
||||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
|
||||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
|
||||||
|
|
||||||
return ConnKey{
|
|
||||||
SrcIP: srcAddr,
|
|
||||||
DstIP: dstAddr,
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -2,7 +2,7 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@ -21,11 +21,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
|||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
@ -46,11 +46,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
|||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -2,7 +2,6 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -70,8 +69,13 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
|
|||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
|
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16) (ICMPConnKey, bool) {
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
key := ICMPConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
ID: id,
|
||||||
|
Sequence: seq,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -87,7 +91,7 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP connection
|
// TrackOutbound records an outbound ICMP connection
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
|
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
|
||||||
@ -95,12 +99,12 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound records an inbound ICMP Echo Request
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
|
||||||
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
|
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
|
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
|
||||||
// TODO: icmp doesn't need to extend the timeout
|
// TODO: icmp doesn't need to extend the timeout
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
|
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
|
||||||
if exists {
|
if exists {
|
||||||
@ -112,7 +116,7 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
|
|||||||
// non echo requests don't need tracking
|
// non echo requests don't need tracking
|
||||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
t.sendStartEvent(direction, key, typ, code)
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,8 +124,8 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
|
|||||||
BaseConnTrack: BaseConnTrack{
|
BaseConnTrack: BaseConnTrack{
|
||||||
FlowId: uuid.New(),
|
FlowId: uuid.New(),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
SourceIP: key.SrcIP,
|
SourceIP: srcIP,
|
||||||
DestIP: key.DstIP,
|
DestIP: dstIP,
|
||||||
},
|
},
|
||||||
ICMPType: typ,
|
ICMPType: typ,
|
||||||
ICMPCode: code,
|
ICMPCode: code,
|
||||||
@ -133,16 +137,21 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
|
|||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
t.sendEvent(nftypes.TypeStart, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, icmpType uint8) bool {
|
||||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
key := ICMPConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
ID: id,
|
||||||
|
Sequence: seq,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -177,7 +186,7 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
|
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
|
||||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
t.sendEvent(nftypes.TypeEnd, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -192,40 +201,28 @@ func (t *ICMPTracker) Close() {
|
|||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
|
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) {
|
||||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
FlowID: conn.FlowId,
|
FlowID: conn.FlowId,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: conn.Direction,
|
Direction: conn.Direction,
|
||||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
SourceIP: key.SrcIP,
|
SourceIP: conn.SourceIP,
|
||||||
DestIP: key.DstIP,
|
DestIP: conn.DestIP,
|
||||||
ICMPType: conn.ICMPType,
|
ICMPType: conn.ICMPType,
|
||||||
ICMPCode: conn.ICMPCode,
|
ICMPCode: conn.ICMPCode,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) {
|
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) {
|
||||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
FlowID: uuid.New(),
|
FlowID: uuid.New(),
|
||||||
Type: nftypes.TypeStart,
|
Type: nftypes.TypeStart,
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Protocol: nftypes.ICMP,
|
Protocol: nftypes.ICMP,
|
||||||
SourceIP: key.SrcIP,
|
SourceIP: srcIP,
|
||||||
DestIP: key.DstIP,
|
DestIP: dstIP,
|
||||||
ICMPType: typ,
|
ICMPType: typ,
|
||||||
ICMPCode: code,
|
ICMPCode: code,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeICMPKey creates an ICMP connection key
|
|
||||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
|
||||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
|
||||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
|
||||||
return ICMPConnKey{
|
|
||||||
SrcIP: srcAddr,
|
|
||||||
DstIP: dstAddr,
|
|
||||||
ID: id,
|
|
||||||
Sequence: seq,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -10,8 +10,8 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
@ -23,8 +23,8 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
|
@ -3,7 +3,7 @@ package conntrack
|
|||||||
// TODO: Send RST packets for invalid/timed-out connections
|
// TODO: Send RST packets for invalid/timed-out connections
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -144,8 +144,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
|||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
|
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -154,7 +159,6 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
if exists {
|
if exists {
|
||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||||
conn.UpdateLastSeen()
|
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
|
|
||||||
return key, true
|
return key, true
|
||||||
@ -164,7 +168,7 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound TCP connection
|
// TrackOutbound records an outbound TCP connection
|
||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
|
||||||
@ -172,12 +176,12 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound connections
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
|
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
|
||||||
if exists {
|
if exists {
|
||||||
return
|
return
|
||||||
@ -187,14 +191,13 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
|||||||
BaseConnTrack: BaseConnTrack{
|
BaseConnTrack: BaseConnTrack{
|
||||||
FlowId: uuid.New(),
|
FlowId: uuid.New(),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
SourceIP: key.SrcIP,
|
SourceIP: srcIP,
|
||||||
DestIP: key.DstIP,
|
DestIP: dstIP,
|
||||||
SourcePort: srcPort,
|
SourcePort: srcPort,
|
||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
conn.established.Store(false)
|
conn.established.Store(false)
|
||||||
conn.tombstone.Store(false)
|
conn.tombstone.Store(false)
|
||||||
|
|
||||||
@ -205,12 +208,17 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
|||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
t.sendEvent(nftypes.TypeStart, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -233,13 +241,12 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
|
|
||||||
t.logger.Trace("TCP connection reset: %s", key)
|
t.logger.Trace("TCP connection reset: %s", key)
|
||||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
t.sendEvent(nftypes.TypeEnd, conn)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(key, conn, flags, false)
|
t.updateState(key, conn, flags, false)
|
||||||
conn.UpdateLastSeen()
|
|
||||||
isEstablished := conn.IsEstablished()
|
isEstablished := conn.IsEstablished()
|
||||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
@ -249,6 +256,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
|
|
||||||
// updateState updates the TCP connection state based on flags
|
// updateState updates the TCP connection state based on flags
|
||||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
state := conn.State
|
state := conn.State
|
||||||
defer func() {
|
defer func() {
|
||||||
if state != conn.State {
|
if state != conn.State {
|
||||||
@ -312,7 +321,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
|||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
|
|
||||||
t.logger.Trace("TCP connection %s completed", key)
|
t.logger.Trace("TCP connection %s completed", key)
|
||||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
t.sendEvent(nftypes.TypeEnd, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateClosing:
|
case TCPStateClosing:
|
||||||
@ -321,7 +330,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
|||||||
// Keep established = false from previous state
|
// Keep established = false from previous state
|
||||||
|
|
||||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
t.sendEvent(nftypes.TypeEnd, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
@ -335,7 +344,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
|
|||||||
conn.SetTombstone()
|
conn.SetTombstone()
|
||||||
|
|
||||||
// Send close event for gracefully closed connections
|
// Send close event for gracefully closed connections
|
||||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
t.sendEvent(nftypes.TypeEnd, conn)
|
||||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -422,7 +431,7 @@ func (t *TCPTracker) cleanup() {
|
|||||||
|
|
||||||
// event already handled by state change
|
// event already handled by state change
|
||||||
if conn.State != TCPStateTimeWait {
|
if conn.State != TCPStateTimeWait {
|
||||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
t.sendEvent(nftypes.TypeEnd, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -453,15 +462,15 @@ func isValidFlagCombination(flags uint8) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
|
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) {
|
||||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
FlowID: conn.FlowId,
|
FlowID: conn.FlowId,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: conn.Direction,
|
Direction: conn.Direction,
|
||||||
Protocol: nftypes.TCP,
|
Protocol: nftypes.TCP,
|
||||||
SourceIP: key.SrcIP,
|
SourceIP: conn.SourceIP,
|
||||||
DestIP: key.DstIP,
|
DestIP: conn.DestIP,
|
||||||
SourcePort: key.SrcPort,
|
SourcePort: conn.SourcePort,
|
||||||
DestPort: key.DstPort,
|
DestPort: conn.DestPort,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
tt.sendRST()
|
tt.sendRST()
|
||||||
|
|
||||||
// Verify connection state is as expected
|
// Verify connection state is as expected
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn := tracker.connections[key]
|
conn := tracker.connections[key]
|
||||||
if tt.wantValid {
|
if tt.wantValid {
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
@ -220,7 +225,7 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper to establish a TCP connection
|
// Helper to establish a TCP connection
|
||||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||||
@ -236,8 +241,8 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
@ -249,8 +254,8 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
@ -267,8 +272,8 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
i := 0
|
i := 0
|
||||||
@ -291,8 +296,8 @@ func BenchmarkCleanup(b *testing.B) {
|
|||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
// Pre-populate with expired connections
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
for i := 0; i < 10000; i++ {
|
for i := 0; i < 10000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
|
||||||
@ -62,12 +62,17 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound records an inbound UDP connection
|
// TrackInbound records an inbound UDP connection
|
||||||
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
|
||||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
|
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -82,7 +87,7 @@ func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound connections
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
|
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
|
||||||
if exists {
|
if exists {
|
||||||
return
|
return
|
||||||
@ -92,8 +97,8 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
|||||||
BaseConnTrack: BaseConnTrack{
|
BaseConnTrack: BaseConnTrack{
|
||||||
FlowId: uuid.New(),
|
FlowId: uuid.New(),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
SourceIP: key.SrcIP,
|
SourceIP: srcIP,
|
||||||
DestIP: key.DstIP,
|
DestIP: dstIP,
|
||||||
SourcePort: srcPort,
|
SourcePort: srcPort,
|
||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
},
|
},
|
||||||
@ -105,12 +110,17 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
|
|||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||||
t.sendEvent(nftypes.TypeStart, key, conn)
|
t.sendEvent(nftypes.TypeStart, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool {
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
@ -146,7 +156,7 @@ func (t *UDPTracker) cleanup() {
|
|||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Removed UDP connection %s (timeout)", key)
|
t.logger.Trace("Removed UDP connection %s (timeout)", key)
|
||||||
t.sendEvent(nftypes.TypeEnd, key, conn)
|
t.sendEvent(nftypes.TypeEnd, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -162,11 +172,16 @@ func (t *UDPTracker) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetConnection safely retrieves a connection state
|
// GetConnection safely retrieves a connection state
|
||||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
defer t.mutex.RUnlock()
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
return conn, exists
|
return conn, exists
|
||||||
}
|
}
|
||||||
@ -176,15 +191,15 @@ func (t *UDPTracker) Timeout() time.Duration {
|
|||||||
return t.timeout
|
return t.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
|
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) {
|
||||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
FlowID: conn.FlowId,
|
FlowID: conn.FlowId,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: conn.Direction,
|
Direction: conn.Direction,
|
||||||
Protocol: nftypes.UDP,
|
Protocol: nftypes.UDP,
|
||||||
SourceIP: key.SrcIP,
|
SourceIP: conn.SourceIP,
|
||||||
DestIP: key.DstIP,
|
DestIP: conn.DestIP,
|
||||||
SourcePort: key.SrcPort,
|
SourcePort: conn.SourcePort,
|
||||||
DestPort: key.DstPort,
|
DestPort: conn.DestPort,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -49,10 +48,15 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
|||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn, exists := tracker.connections[key]
|
conn, exists := tracker.connections[key]
|
||||||
require.True(t, exists)
|
require.True(t, exists)
|
||||||
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||||
@ -66,8 +70,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
@ -76,8 +80,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
sleep time.Duration
|
sleep time.Duration
|
||||||
@ -94,7 +98,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid source IP",
|
name: "invalid source IP",
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: srcIP,
|
dstIP: srcIP,
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
@ -104,7 +108,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid destination IP",
|
name: "invalid destination IP",
|
||||||
srcIP: dstIP,
|
srcIP: dstIP,
|
||||||
dstIP: net.ParseIP("192.168.1.4"),
|
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
sleep: 0,
|
sleep: 0,
|
||||||
@ -170,20 +174,20 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
// Add some connections
|
// Add some connections
|
||||||
connections := []struct {
|
connections := []struct {
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.2"),
|
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
dstIP: net.ParseIP("192.168.1.3"),
|
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||||
srcPort: 12345,
|
srcPort: 12345,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: net.ParseIP("192.168.1.5"),
|
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||||
srcPort: 12346,
|
srcPort: 12346,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
@ -215,8 +219,8 @@ func BenchmarkUDPTracker(b *testing.B) {
|
|||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
@ -228,8 +232,8 @@ func BenchmarkUDPTracker(b *testing.B) {
|
|||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
|
@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
ipv4 := ip.To4()
|
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||||
if ipv4 == nil {
|
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||||
return false
|
|
||||||
}
|
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
|
||||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if ip.Is4() {
|
||||||
return m.checkBitmapBit(ipv4)
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
@ -2,6 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAddr iface.WGAddress
|
setupAddr iface.WGAddress
|
||||||
testIP net.IP
|
testIP netip.Addr
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -73,7 +74,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -85,7 +86,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(64, 128),
|
Mask: net.CIDRMask(64, 128),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
|||||||
t.Logf("Testing %d IPs", len(tests))
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.ip, func(t *testing.T) {
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@ -13,7 +12,7 @@ import (
|
|||||||
type PeerRule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
mgmtId []byte
|
mgmtId []byte
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
|
@ -2,7 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@ -53,8 +53,8 @@ type TraceResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketTrace struct {
|
type PacketTrace struct {
|
||||||
SourceIP net.IP
|
SourceIP netip.Addr
|
||||||
DestinationIP net.IP
|
DestinationIP netip.Addr
|
||||||
Protocol string
|
Protocol string
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestinationPort uint16
|
DestinationPort uint16
|
||||||
@ -72,8 +72,8 @@ type TCPState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketBuilder struct {
|
type PacketBuilder struct {
|
||||||
SrcIP net.IP
|
SrcIP netip.Addr
|
||||||
DstIP net.IP
|
DstIP netip.Addr
|
||||||
Protocol fw.Protocol
|
Protocol fw.Protocol
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
|||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
SrcIP: p.SrcIP,
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
DstIP: p.DstIP,
|
DstIP: p.DstIP.AsSlice(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,7 +260,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
@ -273,14 +273,14 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder
|
|||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return m.handleNativeRouter(trace)
|
return m.handleNativeRouter(trace)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||||
msg := "No existing connection found"
|
msg := "No existing connection found"
|
||||||
if allowed {
|
if allowed {
|
||||||
@ -309,13 +309,12 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
if !m.localForwarding {
|
if !m.localForwarding {
|
||||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
|
|
||||||
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
@ -341,7 +340,7 @@ func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
return false
|
return false
|
||||||
@ -357,7 +356,7 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
|||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||||
proto, _ := getProtocolFromPacket(d)
|
proto, _ := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
@ -373,7 +372,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
|
|||||||
}
|
}
|
||||||
trace.AddResult(StageRouteACL, msg, allowed)
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
if allowed && m.forwarder != nil {
|
if allowed && m.forwarder.Load() != nil {
|
||||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
0
client/firewall/uspfilter/tracer_test.go
Normal file
0
client/firewall/uspfilter/tracer_test.go
Normal file
@ -10,6 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
@ -66,9 +67,9 @@ func (r RouteRules) Sort() {
|
|||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
// outgoingRules is used for hooks only
|
// outgoingRules is used for hooks only
|
||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[netip.Addr]RuleSet
|
||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
@ -80,9 +81,9 @@ type Manager struct {
|
|||||||
// indicates whether server routes are disabled
|
// indicates whether server routes are disabled
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
// indicates whether we forward packets not destined for ourselves
|
// indicates whether we forward packets not destined for ourselves
|
||||||
routingEnabled bool
|
routingEnabled atomic.Bool
|
||||||
// indicates whether we leave forwarding and filtering to the native firewall
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
nativeRouter bool
|
nativeRouter atomic.Bool
|
||||||
// indicates whether we track outbound connections
|
// indicates whether we track outbound connections
|
||||||
stateful bool
|
stateful bool
|
||||||
// indicates whether wireguards runs in netstack mode
|
// indicates whether wireguards runs in netstack mode
|
||||||
@ -95,7 +96,7 @@ type Manager struct {
|
|||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
forwarder *forwarder.Forwarder
|
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
@ -168,18 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
nativeFirewall: nativeFirewall,
|
nativeFirewall: nativeFirewall,
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[netip.Addr]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
localipmanager: newLocalIPManager(),
|
localipmanager: newLocalIPManager(),
|
||||||
disableServerRoutes: disableServerRoutes,
|
disableServerRoutes: disableServerRoutes,
|
||||||
routingEnabled: false,
|
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
}
|
}
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
@ -211,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||||
if m.forwarder == nil {
|
if m.forwarder.Load() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
@ -255,20 +256,20 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case disableUspRouting:
|
case disableUspRouting:
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
log.Info("userspace routing is disabled")
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
case m.disableServerRoutes:
|
case m.disableServerRoutes:
|
||||||
// if server routes are disabled we will let packets pass to the native stack
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
case forceUserspaceRouter:
|
case forceUserspaceRouter:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
@ -276,19 +277,19 @@ func (m *Manager) determineRouting() error {
|
|||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("native routing is enabled")
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
default:
|
default:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing enabled by default")
|
log.Info("userspace routing enabled by default")
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.routingEnabled && !m.nativeRouter {
|
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||||
return m.initForwarder()
|
return m.initForwarder()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,24 +298,24 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
// initForwarder initializes the forwarder, it disables routing on errors
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
func (m *Manager) initForwarder() error {
|
func (m *Manager) initForwarder() error {
|
||||||
if m.forwarder != nil {
|
if m.forwarder.Load() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
intf := m.wgIface.GetWGDevice()
|
intf := m.wgIface.GetWGDevice()
|
||||||
if intf == nil {
|
if intf == nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return errors.New("forwarding not supported")
|
return errors.New("forwarding not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return fmt.Errorf("create forwarder: %w", err)
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder = forwarder
|
m.forwarder.Store(forwarder)
|
||||||
|
|
||||||
log.Debug("forwarder initialized")
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
@ -330,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -341,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -360,17 +361,23 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
_ string,
|
_ string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
|
// TODO: fix in upper layers
|
||||||
|
i, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
i = i.Unmap()
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
ip: ip,
|
ip: i,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
matchByIP: true,
|
matchByIP: true,
|
||||||
drop: action == firewall.ActionDrop,
|
drop: action == firewall.ActionDrop,
|
||||||
}
|
}
|
||||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
if i.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
r.ip = ipNormalized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||||
@ -395,10 +402,10 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
m.incomingRules[r.ip] = make(RuleSet)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
@ -412,13 +419,10 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
@ -432,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
action: action,
|
action: action,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
m.routeRules = append(m.routeRules, rule)
|
m.routeRules = append(m.routeRules, rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -468,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
delete(m.incomingRules[r.ip], r.id)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -519,9 +525,6 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@ -534,7 +537,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -551,14 +555,18 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||||
switch d.decoded[0] {
|
switch d.decoded[0] {
|
||||||
case layers.LayerTypeIPv4:
|
case layers.LayerTypeIPv4:
|
||||||
return d.ip4.SrcIP, d.ip4.DstIP
|
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
|
return src, dst
|
||||||
case layers.LayerTypeIPv6:
|
case layers.LayerTypeIPv6:
|
||||||
return d.ip6.SrcIP, d.ip6.DstIP
|
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||||
|
return src, dst
|
||||||
default:
|
default:
|
||||||
return nil, nil
|
return netip.Addr{}, netip.Addr{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -585,7 +593,7 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
|||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
|
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) {
|
||||||
transport := d.decoded[1]
|
transport := d.decoded[1]
|
||||||
switch transport {
|
switch transport {
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
@ -598,7 +606,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
|
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) {
|
||||||
transport := d.decoded[1]
|
transport := d.decoded[1]
|
||||||
switch transport {
|
switch transport {
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
@ -611,8 +619,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte) bool {
|
||||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
for _, ipKey := range []netip.Addr{dstIP, netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
||||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
@ -627,9 +638,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
|
|||||||
// dropFilter implements filtering logic for incoming packets.
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@ -638,7 +646,7 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -658,15 +666,13 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleLocalTraffic handles local traffic.
|
// handleLocalTraffic handles local traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||||
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
|
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
|
||||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
|
||||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
|
||||||
_, pnum := getProtocolFromPacket(d)
|
_, pnum := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort)
|
ruleId, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
FlowID: uuid.New(),
|
FlowID: uuid.New(),
|
||||||
@ -674,8 +680,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
|
|||||||
RuleID: ruleId,
|
RuleID: ruleId,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: pnum,
|
Protocol: pnum,
|
||||||
SourceIP: srcAddr,
|
SourceIP: srcIP,
|
||||||
DestIP: dstAddr,
|
DestIP: dstIP,
|
||||||
SourcePort: srcPort,
|
SourcePort: srcPort,
|
||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
// TODO: icmp type/code
|
// TODO: icmp type/code
|
||||||
@ -700,12 +706,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder == nil {
|
if m.forwarder.Load() == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject local packet: %v", err)
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -715,16 +721,16 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleRoutedTraffic handles routed traffic.
|
// handleRoutedTraffic handles routed traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||||
// Drop if routing is disabled
|
// Drop if routing is disabled
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
srcIP, dstIP)
|
srcIP, dstIP)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass to native stack if native router is enabled or forced
|
// Pass to native stack if native router is enabled or forced
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -732,9 +738,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
|||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
|
||||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
|
||||||
|
|
||||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
id, pnum, srcIP, srcPort, dstIP, dstPort)
|
id, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
@ -744,8 +747,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
|||||||
RuleID: id,
|
RuleID: id,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: pnum,
|
Protocol: pnum,
|
||||||
SourceIP: srcAddr,
|
SourceIP: srcIP,
|
||||||
DestIP: dstAddr,
|
DestIP: dstIP,
|
||||||
SourcePort: srcPort,
|
SourcePort: srcPort,
|
||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
// TODO: icmp type/code
|
// TODO: icmp type/code
|
||||||
@ -754,7 +757,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Let forwarder handle the packet if it passed route ACLs
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -799,7 +802,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
return m.tcpTracker.IsValidInbound(
|
return m.tcpTracker.IsValidInbound(
|
||||||
@ -844,20 +847,22 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
|||||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) {
|
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
if m.isSpecialICMP(d) {
|
if m.isSpecialICMP(d) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||||
return mgmtId, filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||||
return mgmtId, filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||||
return mgmtId, filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -882,10 +887,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -919,16 +924,13 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
|
|||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
|
||||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
|
||||||
|
|
||||||
for _, rule := range m.routeRules {
|
for _, rule := range m.routeRules {
|
||||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||||
return rule.mgmtId, rule.action == firewall.ActionAccept
|
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -972,9 +974,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
|
||||||
) string {
|
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
@ -984,23 +984,22 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
udpHook: hook,
|
udpHook: hook,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@ -1048,20 +1047,21 @@ func (m *Manager) DisableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.forwarder == nil {
|
fwder := m.forwarder.Load()
|
||||||
|
if fwder == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
// don't stop forwarder if in use by netstack
|
// don't stop forwarder if in use by netstack
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
m.forwarder = nil
|
m.forwarder.Store(nil)
|
||||||
|
|
||||||
log.Debug("forwarder stopped")
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -306,8 +306,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled)
|
require.True(tb, manager.routingEnabled.Load())
|
||||||
require.False(tb, manager.nativeRouter)
|
require.False(tb, manager.nativeRouter.Load())
|
||||||
|
|
||||||
tb.Cleanup(func() {
|
tb.Cleanup(func() {
|
||||||
require.NoError(tb, manager.Reset(nil))
|
require.NoError(tb, manager.Reset(nil))
|
||||||
@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
})
|
})
|
||||||
|
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
for i, p := range tc.packets {
|
for i, p := range tc.packets {
|
||||||
srcIP := net.ParseIP(p.srcIP)
|
srcIP := netip.MustParseAddr(p.srcIP)
|
||||||
dstIP := net.ParseIP(p.dstIP)
|
dstIP := netip.MustParseAddr(p.dstIP)
|
||||||
|
|
||||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||||
|
@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
in bool
|
in bool
|
||||||
expDir fw.RuleDirection
|
expDir fw.RuleDirection
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
dPort uint16
|
dPort uint16
|
||||||
hook func([]byte) bool
|
hook func([]byte) bool
|
||||||
expectedID string
|
expectedID string
|
||||||
@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Outgoing UDP Packet Hook",
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
in: false,
|
in: false,
|
||||||
expDir: fw.RuleDirectionOUT,
|
expDir: fw.RuleDirectionOUT,
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
ip: netip.MustParseAddr("10.168.0.1"),
|
||||||
dPort: 8000,
|
dPort: 8000,
|
||||||
hook: func([]byte) bool { return true },
|
hook: func([]byte) bool { return true },
|
||||||
},
|
},
|
||||||
@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Incoming UDP Packet Hook",
|
name: "Test Incoming UDP Packet Hook",
|
||||||
in: true,
|
in: true,
|
||||||
expDir: fw.RuleDirectionIN,
|
expDir: fw.RuleDirectionIN,
|
||||||
ip: net.IPv6loopback,
|
ip: netip.MustParseAddr("::1"),
|
||||||
dPort: 9000,
|
dPort: 9000,
|
||||||
hook: func([]byte) bool { return false },
|
hook: func([]byte) bool { return false },
|
||||||
},
|
},
|
||||||
@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
var addedRule PeerRule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
for _, rule := range manager.incomingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
hookCalled := false
|
hookCalled := false
|
||||||
hookID := manager.AddUDPPacketHook(
|
hookID := manager.AddUDPPacketHook(
|
||||||
false,
|
false,
|
||||||
net.ParseIP("100.10.0.100"),
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
53,
|
53,
|
||||||
func([]byte) bool {
|
func([]byte) bool {
|
||||||
hookCalled = true
|
hookCalled = true
|
||||||
@ -573,7 +573,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||||
@ -641,7 +641,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
if cp.shouldAllow {
|
if cp.shouldAllow {
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
require.True(t, exists, "Connection should still exist during valid window")
|
require.True(t, exists, "Connection should still exist during valid window")
|
||||||
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||||
"LastSeen should be updated for valid responses")
|
"LastSeen should be updated for valid responses")
|
||||||
|
@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@ -19,7 +20,7 @@ type PacketFilter interface {
|
|||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||||
// Hook function receives raw network packet data as argument.
|
// Hook function receives raw network packet data as argument.
|
||||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
|
@ -6,6 +6,7 @@ package mocks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
net "net"
|
||||||
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook mocks base method.
|
// AddUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(string)
|
ret0, _ := ret[0].(string)
|
||||||
|
@ -2,7 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
@ -41,11 +42,21 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
srcIP = engine.GetWgAddr()
|
srcIP = engine.GetWgAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srcAddr, ok := netip.AddrFromSlice(srcIP)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid source IP address")
|
||||||
|
}
|
||||||
|
|
||||||
dstIP := net.ParseIP(req.GetDestinationIp())
|
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||||
if req.GetDestinationIp() == "self" {
|
if req.GetDestinationIp() == "self" {
|
||||||
dstIP = engine.GetWgAddr()
|
dstIP = engine.GetWgAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dstAddr, ok := netip.AddrFromSlice(dstIP)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid source IP address")
|
||||||
|
}
|
||||||
|
|
||||||
if srcIP == nil || dstIP == nil {
|
if srcIP == nil || dstIP == nil {
|
||||||
return nil, fmt.Errorf("invalid IP address")
|
return nil, fmt.Errorf("invalid IP address")
|
||||||
}
|
}
|
||||||
@ -85,8 +96,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
builder := &uspfilter.PacketBuilder{
|
builder := &uspfilter.PacketBuilder{
|
||||||
SrcIP: srcIP,
|
SrcIP: srcAddr,
|
||||||
DstIP: dstIP,
|
DstIP: dstAddr,
|
||||||
Protocol: protocol,
|
Protocol: protocol,
|
||||||
SrcPort: uint16(req.GetSourcePort()),
|
SrcPort: uint16(req.GetSourcePort()),
|
||||||
DstPort: uint16(req.GetDestinationPort()),
|
DstPort: uint16(req.GetDestinationPort()),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user