mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-24 19:51:33 +02:00
Add userspace flow implementation (#3393)
This commit is contained in:
parent
cccc615783
commit
fa748a7ec2
@ -10,17 +10,18 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
@ -22,17 +22,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if m.forwarder != nil {
|
||||||
|
@ -31,17 +31,17 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if m.forwarder != nil {
|
||||||
|
@ -1,20 +1,26 @@
|
|||||||
// common.go
|
|
||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseConnTrack provides common fields and locking for all connection types
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
type BaseConnTrack struct {
|
type BaseConnTrack struct {
|
||||||
SourceIP net.IP
|
FlowId uuid.UUID
|
||||||
DestIP net.IP
|
Direction nftypes.Direction
|
||||||
|
SourceIP netip.Addr
|
||||||
|
DestIP netip.Addr
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestPort uint16
|
DestPort uint16
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@ -35,92 +41,27 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
|||||||
return time.Since(lastSeen) > timeout
|
return time.Since(lastSeen) > timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPAddr is a fixed-size IP address to avoid allocations
|
|
||||||
type IPAddr [16]byte
|
|
||||||
|
|
||||||
// MakeIPAddr creates an IPAddr from net.IP
|
|
||||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
|
||||||
// Optimization: check for v4 first as it's more common
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
copy(addr[12:], ip4)
|
|
||||||
} else {
|
|
||||||
copy(addr[:], ip.To16())
|
|
||||||
}
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnKey uniquely identifies a connection
|
// ConnKey uniquely identifies a connection
|
||||||
type ConnKey struct {
|
type ConnKey struct {
|
||||||
SrcIP IPAddr
|
SrcIP netip.Addr
|
||||||
DstIP IPAddr
|
DstIP netip.Addr
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c ConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
|
}
|
||||||
|
|
||||||
// makeConnKey creates a connection key
|
// makeConnKey creates a connection key
|
||||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||||
|
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||||
|
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||||
|
|
||||||
return ConnKey{
|
return ConnKey{
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
SrcIP: srcAddr,
|
||||||
DstIP: MakeIPAddr(dstIP),
|
DstIP: dstAddr,
|
||||||
SrcPort: srcPort,
|
SrcPort: srcPort,
|
||||||
DstPort: dstPort,
|
DstPort: dstPort,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateIPs checks if IPs match without allocation
|
|
||||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
|
||||||
if ip4 := pktIP.To4(); ip4 != nil {
|
|
||||||
// Compare IPv4 addresses (last 4 bytes)
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
if connIP[12+i] != ip4[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Compare full IPv6 addresses
|
|
||||||
ip6 := pktIP.To16()
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
if connIP[i] != ip6[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
|
||||||
type PreallocatedIPs struct {
|
|
||||||
sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPreallocatedIPs creates a new IP pool
|
|
||||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
|
||||||
return &PreallocatedIPs{
|
|
||||||
Pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
ip := make(net.IP, 16)
|
|
||||||
return &ip
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get retrieves an IP from the pool
|
|
||||||
func (p *PreallocatedIPs) Get() net.IP {
|
|
||||||
return *p.Pool.Get().(*net.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put returns an IP to the pool
|
|
||||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
|
||||||
p.Pool.Put(&ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyIP copies an IP address efficiently
|
|
||||||
func copyIP(dst, src net.IP) {
|
|
||||||
if len(src) == 16 {
|
|
||||||
copy(dst, src)
|
|
||||||
} else {
|
|
||||||
// Handle IPv4
|
|
||||||
copy(dst[12:], src.To4())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,50 +1,23 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = MakeIPAddr(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ValidateIPs", func(b *testing.B) {
|
|
||||||
ip1 := net.ParseIP("192.168.1.1")
|
|
||||||
ip2 := net.ParseIP("192.168.1.1")
|
|
||||||
addr := MakeIPAddr(ip1)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = ValidateIPs(addr, ip2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IPPool", func(b *testing.B) {
|
|
||||||
pool := NewPreallocatedIPs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
ip := pool.Get()
|
|
||||||
pool.Put(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
@ -69,7 +42,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -19,18 +23,19 @@ const (
|
|||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
type ICMPConnKey struct {
|
type ICMPConnKey struct {
|
||||||
// Supports both IPv4 and IPv6
|
SrcIP netip.Addr
|
||||||
SrcIP [16]byte
|
DstIP netip.Addr
|
||||||
DstIP [16]byte
|
Sequence uint16
|
||||||
Sequence uint16 // ICMP sequence number
|
ID uint16
|
||||||
ID uint16 // ICMP identifier
|
}
|
||||||
|
|
||||||
|
func (i *ICMPConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.Sequence, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
type ICMPConnTrack struct {
|
type ICMPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
Sequence uint16
|
|
||||||
ID uint16
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
@ -41,11 +46,11 @@ type ICMPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
@ -56,41 +61,65 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine()
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP Echo Request
|
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
if !exists {
|
t.mutex.RUnlock()
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &ICMPConnTrack{
|
if exists {
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
},
|
|
||||||
ID: id,
|
|
||||||
Sequence: seq,
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
conn.UpdateLastSeen()
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New ICMP connection %v", key)
|
return key, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound ICMP connection
|
||||||
|
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, id, seq, nftypes.Egress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
|
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||||
|
t.track(srcIP, dstIP, id, seq, nftypes.Ingress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
|
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, direction nftypes.Direction) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &ICMPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: key.SrcIP,
|
||||||
|
DestIP: key.DstIP,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
t.logger.Trace("New %s ICMP connection %s", conn.Direction, key)
|
||||||
|
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
@ -105,18 +134,13 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
conn.UpdateLastSeen()
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
return true
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
|
||||||
conn.ID == id &&
|
|
||||||
conn.Sequence == seq
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanupRoutine() {
|
func (t *ICMPTracker) cleanupRoutine() {
|
||||||
@ -129,17 +153,17 @@ func (t *ICMPTracker) cleanupRoutine() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanup() {
|
func (t *ICMPTracker) cleanup() {
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -150,19 +174,29 @@ func (t *ICMPTracker) Close() {
|
|||||||
close(t.done)
|
close(t.done)
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
|
SourceIP: key.SrcIP,
|
||||||
|
DestIP: key.DstIP,
|
||||||
|
// TODO: add icmp code/type,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// makeICMPKey creates an ICMP connection key
|
// makeICMPKey creates an ICMP connection key
|
||||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||||
|
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||||
|
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||||
return ICMPConnKey{
|
return ICMPConnKey{
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
SrcIP: srcAddr,
|
||||||
DstIP: MakeIPAddr(dstIP),
|
DstIP: dstAddr,
|
||||||
ID: id,
|
ID: id,
|
||||||
Sequence: seq,
|
Sequence: seq,
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
@ -8,7 +8,10 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -39,6 +42,35 @@ const (
|
|||||||
// TCPState represents the state of a TCP connection
|
// TCPState represents the state of a TCP connection
|
||||||
type TCPState int
|
type TCPState int
|
||||||
|
|
||||||
|
func (s TCPState) String() string {
|
||||||
|
switch s {
|
||||||
|
case TCPStateNew:
|
||||||
|
return "New"
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return "SYN Sent"
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return "SYN Received"
|
||||||
|
case TCPStateEstablished:
|
||||||
|
return "Established"
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return "FIN Wait 1"
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return "FIN Wait 2"
|
||||||
|
case TCPStateClosing:
|
||||||
|
return "Closing"
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
return "Time Wait"
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return "Close Wait"
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return "Last ACK"
|
||||||
|
case TCPStateClosed:
|
||||||
|
return "Closed"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPStateNew TCPState = iota
|
TCPStateNew TCPState = iota
|
||||||
TCPStateSynSent
|
TCPStateSynSent
|
||||||
@ -53,19 +85,12 @@ const (
|
|||||||
TCPStateClosed
|
TCPStateClosed
|
||||||
)
|
)
|
||||||
|
|
||||||
// TCPConnKey uniquely identifies a TCP connection
|
|
||||||
type TCPConnKey struct {
|
|
||||||
SrcIP [16]byte
|
|
||||||
DstIP [16]byte
|
|
||||||
SrcPort uint16
|
|
||||||
DstPort uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCPConnTrack represents a TCP connection state
|
// TCPConnTrack represents a TCP connection state
|
||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
State TCPState
|
State TCPState
|
||||||
established atomic.Bool
|
established atomic.Bool
|
||||||
|
tombstone atomic.Bool
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,6 +104,16 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
|
|||||||
t.established.Store(state)
|
t.established.Store(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsTombstone safely checks if the connection is marked for deletion
|
||||||
|
func (t *TCPConnTrack) IsTombstone() bool {
|
||||||
|
return t.tombstone.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTombstone safely marks the connection for deletion
|
||||||
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
|
t.tombstone.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
type TCPTracker struct {
|
type TCPTracker struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
@ -87,68 +122,94 @@ type TCPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultTCPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
tracker := &TCPTracker{
|
tracker := &TCPTracker{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*TCPConnTrack),
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine()
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
|
||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
|
||||||
// Create key before lock
|
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
if !exists {
|
t.mutex.RUnlock()
|
||||||
// Use preallocated IPs
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &TCPConnTrack{
|
if exists {
|
||||||
BaseConnTrack: BaseConnTrack{
|
conn.Lock()
|
||||||
SourceIP: srcIPCopy,
|
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
State: TCPStateNew,
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
conn.UpdateLastSeen()
|
||||||
conn.established.Store(false)
|
conn.Unlock()
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
return key, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound TCP connection
|
||||||
|
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
|
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &TCPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: key.SrcIP,
|
||||||
|
DestIP: key.DstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.established.Store(false)
|
||||||
|
conn.tombstone.Store(false)
|
||||||
|
|
||||||
|
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||||
|
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
// Lock individual connection for state update
|
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||||
conn.Lock()
|
|
||||||
t.updateState(conn, flags, true)
|
|
||||||
conn.Unlock()
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||||
if !isValidFlagCombination(flags) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
@ -159,21 +220,25 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle RST packets
|
// Handle RST flag specially - it always causes transition to closed
|
||||||
if flags&TCPRst != 0 {
|
if flags&TCPRst != 0 {
|
||||||
conn.Lock()
|
if conn.IsTombstone() {
|
||||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
conn.Unlock()
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.Lock()
|
||||||
|
conn.SetTombstone()
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
return false
|
|
||||||
|
t.logger.Trace("TCP connection reset: %s", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(conn, flags, false)
|
t.updateState(key, conn, flags, false)
|
||||||
conn.UpdateLastSeen()
|
conn.UpdateLastSeen()
|
||||||
isEstablished := conn.IsEstablished()
|
isEstablished := conn.IsEstablished()
|
||||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||||
@ -183,18 +248,15 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates the TCP connection state based on flags
|
// updateState updates the TCP connection state based on flags
|
||||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||||
// Handle RST flag specially - it always causes transition to closed
|
state := conn.State
|
||||||
if flags&TCPRst != 0 {
|
defer func() {
|
||||||
conn.State = TCPStateClosed
|
if state != conn.State {
|
||||||
conn.SetEstablished(false)
|
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
switch state {
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch conn.State {
|
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
conn.State = TCPStateSynSent
|
conn.State = TCPStateSynSent
|
||||||
@ -241,6 +303,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateFinWait2:
|
case TCPStateFinWait2:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection %s completed", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateClosing:
|
case TCPStateClosing:
|
||||||
@ -248,8 +313,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
// Keep established = false from previous state
|
||||||
|
|
||||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
@ -260,17 +325,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
|
||||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
// Send close event for gracefully closed connections
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||||
|
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
|
||||||
// This is handled by the cleanup routine
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -331,6 +391,12 @@ func (t *TCPTracker) cleanup() {
|
|||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
|
if conn.IsTombstone() {
|
||||||
|
// Clean up tombstoned connections without sending an event
|
||||||
|
delete(t.connections, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
switch {
|
switch {
|
||||||
case conn.State == TCPStateTimeWait:
|
case conn.State == TCPStateTimeWait:
|
||||||
@ -341,14 +407,16 @@ func (t *TCPTracker) cleanup() {
|
|||||||
timeout = TCPHandshakeTimeout
|
timeout = TCPHandshakeTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
lastSeen := conn.GetLastSeen()
|
if conn.timeoutExceeded(timeout) {
|
||||||
if time.Since(lastSeen) > timeout {
|
|
||||||
// Return IPs to pool
|
// Return IPs to pool
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key)
|
||||||
|
|
||||||
|
// event already handled by state change
|
||||||
|
if conn.State != TCPStateTimeWait {
|
||||||
|
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -360,10 +428,6 @@ func (t *TCPTracker) Close() {
|
|||||||
|
|
||||||
// Clean up all remaining IPs
|
// Clean up all remaining IPs
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@ -381,3 +445,16 @@ func isValidFlagCombination(flags uint8) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
SourceIP: key.SrcIP,
|
||||||
|
DestIP: key.DstIP,
|
||||||
|
SourcePort: key.SrcPort,
|
||||||
|
DestPort: key.DstPort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := net.ParseIP("100.64.0.1")
|
||||||
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
|
|||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
// Benchmark connection cleanup
|
// Benchmark connection cleanup
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
// Pre-populate with expired connections
|
||||||
|
@ -5,7 +5,10 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -28,11 +31,11 @@ type UDPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
@ -43,7 +46,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine()
|
||||||
@ -52,32 +55,57 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
|
|
||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound UDP connection
|
||||||
|
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
if !exists {
|
t.mutex.RUnlock()
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &UDPConnTrack{
|
if exists {
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
conn.UpdateLastSeen()
|
||||||
t.connections[key] = conn
|
return key, true
|
||||||
|
|
||||||
t.logger.Trace("New UDP connection: %v", conn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: key.SrcIP,
|
||||||
|
DestIP: key.DstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||||
|
t.sendEvent(nftypes.TypeStart, key, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
@ -88,18 +116,13 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
conn.UpdateLastSeen()
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
return true
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
|
||||||
conn.DestPort == srcPort &&
|
|
||||||
conn.SourcePort == dstPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRoutine periodically removes stale connections
|
// cleanupRoutine periodically removes stale connections
|
||||||
@ -120,11 +143,10 @@ func (t *UDPTracker) cleanup() {
|
|||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
t.logger.Trace("Removed UDP connection %s (timeout)", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, key, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -135,10 +157,6 @@ func (t *UDPTracker) Close() {
|
|||||||
close(t.done)
|
close(t.done)
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@ -150,14 +168,23 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d
|
|||||||
|
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
if !exists {
|
return conn, exists
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout returns the configured timeout duration for the tracker
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
func (t *UDPTracker) Timeout() time.Duration {
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
return t.timeout
|
return t.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
SourceIP: key.SrcIP,
|
||||||
|
DestIP: key.DstIP,
|
||||||
|
SourcePort: key.SrcPort,
|
||||||
|
DestPort: key.DstPort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout, logger)
|
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
@ -40,29 +41,29 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
|
||||||
conn, exists := tracker.connections[key]
|
conn, exists := tracker.connections[key]
|
||||||
require.True(t, exists)
|
require.True(t, exists)
|
||||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1*time.Second, logger)
|
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := net.ParseIP("192.168.1.2")
|
||||||
@ -160,8 +161,8 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
ipPool: NewPreallocatedIPs(),
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
@ -211,7 +212,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
@ -224,7 +225,7 @@ func BenchmarkUDPTracker(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := net.ParseIP("192.168.1.1")
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -29,6 +30,7 @@ const (
|
|||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
@ -38,7 +40,7 @@ type Forwarder struct {
|
|||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &Forwarder{
|
f := &Forwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
stack: s,
|
stack: s,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
udpForwarder: newUDPForwarder(mtu, logger),
|
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
|
@ -3,14 +3,21 @@ package forwarder
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleICMP handles ICMP packets from the network stack
|
// handleICMP handles ICMP packets from the network stack
|
||||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -20,6 +27,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||||
|
|
||||||
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id)
|
||||||
|
|
||||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -27,6 +36,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
@ -101,9 +112,25 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
|
|
||||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: 1,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
SourcePort: id.LocalPort,
|
||||||
|
DestPort: id.RemotePort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -5,18 +5,25 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleTCP is called by the TCP forwarder for new connections.
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
id := r.ID()
|
id := r.ID()
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
|
||||||
|
|
||||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
@ -46,10 +53,10 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||||
|
|
||||||
go f.proxyTCP(id, inConn, outConn, ep)
|
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
@ -58,6 +65,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
}
|
}
|
||||||
ep.Close()
|
ep.Close()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Create context for managing the proxy goroutines
|
// Create context for managing the proxy goroutines
|
||||||
@ -88,3 +97,18 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: 6,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
SourcePort: id.LocalPort,
|
||||||
|
DestPort: id.RemotePort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -5,10 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
@ -16,6 +18,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -28,15 +31,17 @@ type udpPacketConn struct {
|
|||||||
lastSeen atomic.Int64
|
lastSeen atomic.Int64
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ep tcpip.Endpoint
|
ep tcpip.Endpoint
|
||||||
|
flowID uuid.UUID
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpForwarder struct {
|
type udpForwarder struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
flowLogger nftypes.FlowLogger
|
||||||
bufPool sync.Pool
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
ctx context.Context
|
bufPool sync.Pool
|
||||||
cancel context.CancelFunc
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type idleConn struct {
|
type idleConn struct {
|
||||||
@ -44,13 +49,14 @@ type idleConn struct {
|
|||||||
conn *udpPacketConn
|
conn *udpPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &udpForwarder{
|
f := &udpForwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
flowLogger: flowLogger,
|
||||||
ctx: ctx,
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
cancel: cancel,
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
b := make([]byte, mtu)
|
b := make([]byte, mtu)
|
||||||
@ -83,6 +89,21 @@ func (f *udpForwarder) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
|
func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: 17,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
SourcePort: id.LocalPort,
|
||||||
|
DestPort: id.RemotePort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// cleanup periodically removes idle UDP connections
|
// cleanup periodically removes idle UDP connections
|
||||||
func (f *udpForwarder) cleanup() {
|
func (f *udpForwarder) cleanup() {
|
||||||
ticker := time.NewTicker(time.Minute)
|
ticker := time.NewTicker(time.Minute)
|
||||||
@ -119,6 +140,8 @@ func (f *udpForwarder) cleanup() {
|
|||||||
f.Unlock()
|
f.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -141,10 +164,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
|
||||||
|
|
||||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
|
||||||
// TODO: Send ICMP error message
|
// TODO: Send ICMP error message
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -157,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
}
|
}
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,6 +196,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
outConn: outConn,
|
outConn: outConn,
|
||||||
cancel: connCancel,
|
cancel: connCancel,
|
||||||
ep: ep,
|
ep: ep,
|
||||||
|
flowID: flowID,
|
||||||
}
|
}
|
||||||
pConn.updateLastSeen()
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
@ -182,6 +211,8 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[id] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
@ -206,6 +237,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
delete(f.udpForwarder.conns, id)
|
delete(f.udpForwarder.conns, id)
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
@ -231,6 +264,21 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
|
||||||
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: 17, // UDP protocol number
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
SourcePort: id.LocalPort,
|
||||||
|
DestPort: id.RemotePort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (c *udpPacketConn) updateLastSeen() {
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
c.lastSeen.Store(time.Now().UnixNano())
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -13,13 +13,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2 // 2KB per message
|
maxMessageSize = 1024 * 2
|
||||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
logChannelSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
// Level represents log severity
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
|
|||||||
LevelTrace: "TRAC",
|
LevelTrace: "TRAC",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger is a high-performance, non-blocking logger
|
type logMessage struct {
|
||||||
type Logger struct {
|
level Level
|
||||||
output io.Writer
|
format string
|
||||||
level atomic.Uint32
|
args []any
|
||||||
buffer *ringBuffer
|
|
||||||
shutdown chan struct{}
|
|
||||||
closeOnce sync.Once
|
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
// Reusable buffer pool for formatting messages
|
|
||||||
bufPool sync.Pool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger is a high-performance, non-blocking logger
|
||||||
|
type Logger struct {
|
||||||
|
output io.Writer
|
||||||
|
level atomic.Uint32
|
||||||
|
msgChannel chan logMessage
|
||||||
|
shutdown chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
bufPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
buffer: newRingBuffer(bufferSize),
|
msgChannel: make(chan logMessage, logChannelSize),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() any {
|
||||||
// Pre-allocate buffer for message formatting
|
|
||||||
b := make([]byte, 0, maxMessageSize)
|
b := make([]byte, 0, maxMessageSize)
|
||||||
return &b
|
return &b
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logrusLevel := logrusLogger.GetLevel()
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
l.level.Store(uint32(logrusLevel))
|
l.level.Store(uint32(logrusLevel))
|
||||||
level := levelStrings[Level(logrusLevel)]
|
level := levelStrings[Level(logrusLevel)]
|
||||||
@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLevel sets the logging level
|
||||||
func (l *Logger) SetLevel(level Level) {
|
func (l *Logger) SetLevel(level Level) {
|
||||||
l.level.Store(uint32(level))
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
func (l *Logger) log(level Level, format string, args ...any) {
|
||||||
*buf = (*buf)[:0]
|
select {
|
||||||
|
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||||
// Timestamp
|
default:
|
||||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Level
|
|
||||||
*buf = append(*buf, levelStrings[level]...)
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Message
|
|
||||||
if len(args) > 0 {
|
|
||||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
|
||||||
} else {
|
|
||||||
*buf = append(*buf, format...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*buf = append(*buf, '\n')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
// Error logs a message at error level
|
||||||
bufp := l.bufPool.Get().(*[]byte)
|
func (l *Logger) Error(format string, args ...any) {
|
||||||
l.formatMessage(bufp, level, format, args...)
|
|
||||||
|
|
||||||
if len(*bufp) > maxMessageSize {
|
|
||||||
*bufp = (*bufp)[:maxMessageSize]
|
|
||||||
}
|
|
||||||
_, _ = l.buffer.Write(*bufp)
|
|
||||||
|
|
||||||
l.bufPool.Put(bufp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Error(format string, args ...interface{}) {
|
|
||||||
if l.level.Load() >= uint32(LevelError) {
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
l.log(LevelError, format, args...)
|
l.log(LevelError, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
// Warn logs a message at warning level
|
||||||
|
func (l *Logger) Warn(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelWarn) {
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
l.log(LevelWarn, format, args...)
|
l.log(LevelWarn, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Info(format string, args ...interface{}) {
|
// Info logs a message at info level
|
||||||
|
func (l *Logger) Info(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelInfo) {
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
l.log(LevelInfo, format, args...)
|
l.log(LevelInfo, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
// Debug logs a message at debug level
|
||||||
|
func (l *Logger) Debug(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelDebug) {
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
l.log(LevelDebug, format, args...)
|
l.log(LevelDebug, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
// Trace logs a message at trace level
|
||||||
|
func (l *Logger) Trace(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelTrace) {
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
l.log(LevelTrace, format, args...)
|
l.log(LevelTrace, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker periodically flushes the buffer
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
var msg string
|
||||||
|
if len(args) > 0 {
|
||||||
|
msg = fmt.Sprintf(format, args...)
|
||||||
|
} else {
|
||||||
|
msg = format
|
||||||
|
}
|
||||||
|
*buf = append(*buf, msg...)
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
|
||||||
|
if len(*buf) > maxMessageSize {
|
||||||
|
*buf = (*buf)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessage handles a single log message and adds it to the buffer
|
||||||
|
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
defer l.bufPool.Put(bufp)
|
||||||
|
|
||||||
|
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||||
|
|
||||||
|
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
*buffer = append(*buffer, *bufp...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushBuffer writes the accumulated buffer to output
|
||||||
|
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||||
|
if len(*buffer) > 0 {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatch processes as many messages as possible without blocking
|
||||||
|
func (l *Logger) processBatch(buffer *[]byte) {
|
||||||
|
for len(*buffer) < maxBatchSize {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||||
|
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(l.msgChannel) == 0 {
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is the main goroutine that processes log messages
|
||||||
func (l *Logger) worker() {
|
func (l *Logger) worker() {
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
|
||||||
ticker := time.NewTicker(defaultFlushInterval)
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
buf := make([]byte, 0, maxBatchSize)
|
buffer := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-l.shutdown:
|
case <-l.shutdown:
|
||||||
|
l.handleShutdown(&buffer)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Read accumulated messages
|
l.flushBuffer(&buffer)
|
||||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
case msg := <-l.msgChannel:
|
||||||
if n == 0 {
|
l.processMessage(msg, &buffer)
|
||||||
continue
|
l.processBatch(&buffer)
|
||||||
}
|
|
||||||
|
|
||||||
// Write batch
|
|
||||||
_, _ = l.output.Write(buf[:n])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type discard struct{}
|
||||||
|
|
||||||
|
func (d *discard) Write(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger(b *testing.B) {
|
||||||
|
simpleMessage := "Connection established"
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4 // TCPStateEstablished
|
||||||
|
|
||||||
|
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||||
|
protocol := "TCP"
|
||||||
|
direction := "outbound"
|
||||||
|
flags := uint16(0x18) // ACK + PSH
|
||||||
|
sequence := uint32(123456789)
|
||||||
|
acknowledged := uint32(987654321)
|
||||||
|
payloadSize := 1460
|
||||||
|
fragmented := false
|
||||||
|
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||||
|
|
||||||
|
b.Run("SimpleMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(simpleMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ComplexMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||||
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||||
|
func BenchmarkLoggerBurst(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestLogger() *log.Logger {
|
||||||
|
logrusLogger := logrus.New()
|
||||||
|
logrusLogger.SetOutput(&discard{})
|
||||||
|
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||||
|
return log.NewFromLogrus(logrusLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupLogger(logger *log.Logger) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = logger.Stop(ctx)
|
||||||
|
}
|
@ -1,85 +0,0 @@
|
|||||||
package log
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// ringBuffer is a simple ring buffer implementation
|
|
||||||
type ringBuffer struct {
|
|
||||||
buf []byte
|
|
||||||
size int
|
|
||||||
r, w int64 // Read and write positions
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRingBuffer(size int) *ringBuffer {
|
|
||||||
return &ringBuffer{
|
|
||||||
buf: make([]byte, size),
|
|
||||||
size: size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
|
||||||
if len(p) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if len(p) > r.size {
|
|
||||||
p = p[:r.size]
|
|
||||||
}
|
|
||||||
|
|
||||||
n = len(p)
|
|
||||||
|
|
||||||
// Write data, handling wrap-around
|
|
||||||
pos := int(r.w % int64(r.size))
|
|
||||||
writeLen := min(len(p), r.size-pos)
|
|
||||||
copy(r.buf[pos:], p[:writeLen])
|
|
||||||
|
|
||||||
// If we have more data and need to wrap around
|
|
||||||
if writeLen < len(p) {
|
|
||||||
copy(r.buf, p[writeLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update write position
|
|
||||||
r.w += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if r.w == r.r {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate available data accounting for wraparound
|
|
||||||
available := int(r.w - r.r)
|
|
||||||
if available < 0 {
|
|
||||||
available += r.size
|
|
||||||
}
|
|
||||||
available = min(available, r.size)
|
|
||||||
|
|
||||||
// Limit read to buffer size
|
|
||||||
toRead := min(available, len(p))
|
|
||||||
if toRead == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read data, handling wrap-around
|
|
||||||
pos := int(r.r % int64(r.size))
|
|
||||||
readLen := min(toRead, r.size-pos)
|
|
||||||
n = copy(p, r.buf[pos:pos+readLen])
|
|
||||||
|
|
||||||
// If we need more data and need to wrap around
|
|
||||||
if readLen < toRead {
|
|
||||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update read position
|
|
||||||
r.r += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -96,6 +97,7 @@ type Manager struct {
|
|||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
forwarder *forwarder.Forwarder
|
forwarder *forwarder.Forwarder
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@ -112,16 +114,16 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
return create(iface, nil, disableServerRoutes)
|
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
if nativeFirewall == nil {
|
if nativeFirewall == nil {
|
||||||
return nil, errors.New("native firewall is nil")
|
return nil, errors.New("native firewall is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -148,7 +150,7 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@ -174,6 +176,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
routingEnabled: false,
|
routingEnabled: false,
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
}
|
}
|
||||||
@ -185,9 +188,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// netstack needs the forwarder for local traffic
|
// netstack needs the forwarder for local traffic
|
||||||
@ -304,7 +307,7 @@ func (m *Manager) initForwarder() error {
|
|||||||
return errors.New("forwarding not supported")
|
return errors.New("forwarding not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled = false
|
||||||
return fmt.Errorf("create forwarder: %w", err)
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
@ -533,14 +536,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
|
|
||||||
// Track all protocols if stateful mode is enabled
|
// Track all protocols if stateful mode is enabled
|
||||||
if m.stateful {
|
if m.stateful {
|
||||||
switch d.decoded[1] {
|
m.trackOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeICMPv4:
|
|
||||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process UDP hooks even if stateful mode is disabled
|
// Process UDP hooks even if stateful mode is disabled
|
||||||
@ -562,17 +558,6 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|
||||||
flags := getTCPFlags(&d.tcp)
|
|
||||||
m.tcpTracker.TrackOutbound(
|
|
||||||
srcIP,
|
|
||||||
dstIP,
|
|
||||||
uint16(d.tcp.SrcPort),
|
|
||||||
uint16(d.tcp.DstPort),
|
|
||||||
flags,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||||
var flags uint8
|
var flags uint8
|
||||||
if tcp.SYN {
|
if tcp.SYN {
|
||||||
@ -596,13 +581,34 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
|||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
m.udpTracker.TrackOutbound(
|
transport := d.decoded[1]
|
||||||
srcIP,
|
switch transport {
|
||||||
dstIP,
|
case layers.LayerTypeUDP:
|
||||||
uint16(d.udp.SrcPort),
|
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
|
||||||
uint16(d.udp.DstPort),
|
case layers.LayerTypeTCP:
|
||||||
)
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||||
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
|
||||||
|
transport := d.decoded[1]
|
||||||
|
switch transport {
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||||
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||||
@ -618,17 +624,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|
||||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
|
||||||
m.icmpTracker.TrackOutbound(
|
|
||||||
srcIP,
|
|
||||||
dstIP,
|
|
||||||
d.icmp4.Id,
|
|
||||||
d.icmp4.Seq,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets.
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
func (m *Manager) dropFilter(packetData []byte) bool {
|
||||||
@ -675,6 +670,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
|
|||||||
return m.handleNetstackLocalTraffic(packetData)
|
return m.handleNetstackLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
|
m.trackInbound(d, srcIP, dstIP)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -203,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -251,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -450,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -577,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -668,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -787,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
@ -875,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Reset(nil))
|
||||||
})
|
})
|
||||||
|
@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@ -302,7 +302,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -18,9 +20,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -116,7 +120,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -187,7 +191,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
@ -236,7 +240,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -279,7 +283,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -347,7 +351,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface, false)
|
manager, err := Create(iface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@ -393,7 +397,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@ -401,7 +405,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Reset(nil))
|
||||||
}()
|
}()
|
||||||
@ -479,7 +483,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@ -506,7 +510,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@ -515,7 +519,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@ -534,8 +538,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up packet parameters
|
// Set up packet parameters
|
||||||
srcIP := net.ParseIP("100.10.0.1")
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
dstIP := net.ParseIP("100.10.0.100")
|
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||||
srcPort := uint16(51334)
|
srcPort := uint16(51334)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
@ -543,8 +547,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
outboundIPv4 := &layers.IPv4{
|
outboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: srcIP,
|
SrcIP: srcIP.AsSlice(),
|
||||||
DstIP: dstIP,
|
DstIP: dstIP.AsSlice(),
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
outboundUDP := &layers.UDP{
|
outboundUDP := &layers.UDP{
|
||||||
@ -573,11 +577,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
||||||
|
|
||||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
@ -585,8 +589,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
inboundIPv4 := &layers.IPv4{
|
inboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: dstIP, // Original destination is now source
|
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||||
DstIP: srcIP, // Original source is now destination
|
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
inboundUDP := &layers.UDP{
|
inboundUDP := &layers.UDP{
|
||||||
@ -641,7 +645,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
if cp.shouldAllow {
|
if cp.shouldAllow {
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
|
||||||
require.True(t, exists, "Connection should still exist during valid window")
|
require.True(t, exists, "Connection should still exist during valid window")
|
||||||
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||||
"LastSeen should be updated for valid responses")
|
"LastSeen should be updated for valid responses")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -10,9 +11,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
@ -52,7 +56,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@ -346,7 +350,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@ -29,6 +30,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
|
||||||
|
|
||||||
type mocWGIface struct {
|
type mocWGIface struct {
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
}
|
}
|
||||||
@ -916,7 +919,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface, false)
|
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -35,7 +35,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
@ -191,7 +191,7 @@ type Engine struct {
|
|||||||
persistNetworkMap bool
|
persistNetworkMap bool
|
||||||
latestNetworkMap *mgmProto.NetworkMap
|
latestNetworkMap *mgmProto.NetworkMap
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
flowManager types.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@ -454,7 +454,7 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil || e.firewall == nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
return nil
|
return nil
|
||||||
@ -721,11 +721,11 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
|
|||||||
return e.flowManager.Update(flowConfig)
|
return e.flowManager.Update(flowConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*types.FlowConfig, error) {
|
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
||||||
if config.GetInterval() == nil {
|
if config.GetInterval() == nil {
|
||||||
return nil, errors.New("flow interval is nil")
|
return nil, errors.New("flow interval is nil")
|
||||||
}
|
}
|
||||||
return &types.FlowConfig{
|
return &nftypes.FlowConfig{
|
||||||
Enabled: config.GetEnabled(),
|
Enabled: config.GetEnabled(),
|
||||||
URL: config.GetUrl(),
|
URL: config.GetUrl(),
|
||||||
TokenPayload: config.GetTokenPayload(),
|
TokenPayload: config.GetTokenPayload(),
|
||||||
|
@ -7,17 +7,52 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Protocol uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtocolUnknown = 0
|
||||||
|
ICMP = 1
|
||||||
|
TCP = 6
|
||||||
|
UDP = 17
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p Protocol) String() string {
|
||||||
|
switch p {
|
||||||
|
case 1:
|
||||||
|
return "ICMP"
|
||||||
|
case 6:
|
||||||
|
return "TCP"
|
||||||
|
case 17:
|
||||||
|
return "UDP"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Type int
|
type Type int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TypeStart = iota
|
TypeUnknown = iota
|
||||||
|
TypeStart
|
||||||
TypeEnd
|
TypeEnd
|
||||||
)
|
)
|
||||||
|
|
||||||
type Direction int
|
type Direction int
|
||||||
|
|
||||||
|
func (d Direction) String() string {
|
||||||
|
switch d {
|
||||||
|
case Ingress:
|
||||||
|
return "ingress"
|
||||||
|
case Egress:
|
||||||
|
return "egress"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Ingress = iota
|
DirectionUnknown = iota
|
||||||
|
Ingress
|
||||||
Egress
|
Egress
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user