Fix tests

This commit is contained in:
Viktor Liu 2025-01-02 23:34:43 +01:00
parent f772a21f37
commit 0b9854b2b1
5 changed files with 36 additions and 16 deletions

View File

@ -3,8 +3,14 @@ package conntrack
import ( import (
"net" "net"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger())
func BenchmarkIPOperations(b *testing.B) { func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) { b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")

View File

@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, nil) tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, nil) tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@ -9,7 +9,7 @@ import (
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, nil) tracker = NewTCPTracker(DefaultTCPTimeout, logger)
tt.test(t) tt.test(t)
}) })
} }
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
}) })
b.Run("ConcurrentAccess", func(b *testing.B) { b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, nil) tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup // Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) { func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections

View File

@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, nil) tracker := NewUDPTracker(tt.timeout, logger)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, nil) tracker := NewUDPTracker(1*time.Second, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
@ -162,6 +162,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), done: make(chan struct{}),
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
logger: logger,
} }
// Start cleanup routine // Start cleanup routine
@ -211,7 +212,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
@ -224,7 +225,7 @@ func BenchmarkUDPTracker(b *testing.B) {
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, nil) tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")

View File

@ -9,15 +9,19 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger())
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() iface.WGAddress
@ -284,6 +288,15 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@ -409,7 +422,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Reset(nil))
}() }()
@ -527,7 +540,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{