From 0b9854b2b108782674d6614e537b22e5f5f5a2b3 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 2 Jan 2025 23:34:43 +0100 Subject: [PATCH] Fix tests --- .../firewall/uspfilter/conntrack/common_test.go | 6 ++++++ .../firewall/uspfilter/conntrack/icmp_test.go | 4 ++-- client/firewall/uspfilter/conntrack/tcp_test.go | 14 +++++++------- client/firewall/uspfilter/conntrack/udp_test.go | 11 ++++++----- client/firewall/uspfilter/uspfilter_test.go | 17 +++++++++++++++-- 5 files changed, 36 insertions(+), 16 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index b885470a3..79c1f5d8f 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -3,8 +3,14 @@ package conntrack import ( "net" "testing" + + "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) +var logger = log.NewFromLogrus(logrus.StandardLogger()) + func BenchmarkIPOperations(b *testing.B) { b.Run("MakeIPAddr", func(b *testing.B) { ip := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index e653416f9..32553c836 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -7,7 +7,7 @@ import ( func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout, nil) + tracker := NewICMPTracker(DefaultICMPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout, nil) + tracker := NewICMPTracker(DefaultICMPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index c44e7dfa7..5f4c43915 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -9,7 +9,7 @@ import ( ) func TestTCPStateMachine(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout, nil) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Helper() - tracker = NewTCPTracker(DefaultTCPTimeout, nil) + tracker = NewTCPTracker(DefaultTCPTimeout, logger) tt.test(t) }) } @@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) { } func TestRSTHandling(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout, nil) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, func BenchmarkTCPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, nil) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, nil) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("ConcurrentAccess", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, nil) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) { // Benchmark connection cleanup func BenchmarkCleanup(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) { - tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing + tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing defer tracker.Close() // Pre-populate with expired connections diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 4e42c484f..ab4e26d4f 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := NewUDPTracker(tt.timeout, nil) + tracker := NewUDPTracker(tt.timeout, logger) assert.NotNil(t, tracker) assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultUDPTimeout, nil) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { } func TestUDPTracker_IsValidInbound(t *testing.T) { - tracker := NewUDPTracker(1*time.Second, nil) + tracker := NewUDPTracker(1*time.Second, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -162,6 +162,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { cleanupTicker: time.NewTicker(cleanupInterval), done: make(chan struct{}), ipPool: NewPreallocatedIPs(), + logger: logger, } // Start cleanup routine @@ -211,7 +212,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { func BenchmarkUDPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout, nil) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -224,7 +225,7 @@ func BenchmarkUDPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout, nil) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 14a7efe77..a0c0d23c3 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -9,15 +9,19 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) +var logger = log.NewFromLogrus(logrus.StandardLogger()) + type IFaceMock struct { SetFilterFunc func(device.PacketFilter) error AddressFunc func() iface.WGAddress @@ -284,6 +288,15 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.10.0.100"), + Network: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + } + }, } m, err := Create(ifaceMock) @@ -409,7 +422,7 @@ func TestProcessOutgoingHooks(t *testing.T) { Mask: net.CIDRMask(16, 32), } manager.udpTracker.Close() - manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil) + manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) defer func() { require.NoError(t, manager.Reset(nil)) }() @@ -527,7 +540,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } manager.udpTracker.Close() // Close the existing tracker - manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil) + manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger) manager.decoders = sync.Pool{ New: func() any { d := &decoder{