diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index cc5edc554..53350797c 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -265,8 +265,10 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder return trace } - if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { - return trace + if m.localipmanager.IsLocalIP(dstIP) { + if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { + return trace + } } if !m.handleRouting(trace) { @@ -310,32 +312,40 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string { } func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { - if !m.localForwarding { - trace.AddResult(StageRouting, "Local forwarding disabled", false) - trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) - return true - } trace.AddResult(StageRouting, "Packet destined for local delivery", true) ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) - strRuleId := "implicit" + strRuleId := "" if ruleId != nil { strRuleId = string(ruleId) } - msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId) if blocked { msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId) + trace.AddResult(StagePeerACL, msg, false) + trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false) + return true } - trace.AddResult(StagePeerACL, msg, !blocked) + trace.AddResult(StagePeerACL, msg, true) + // Handle netstack mode if m.netstack { - m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked) + switch { + case !m.localForwarding: + trace.AddResult(StageCompleted, "Packet sent to virtual stack", true) + case m.forwarder.Load() != nil: + m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true) + trace.AddResult(StageCompleted, msgProcessingCompleted, true) + default: + trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false) + } + return true } - trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked) + // In normal mode, packets are allowed through for local delivery + trace.AddResult(StageCompleted, msgProcessingCompleted, true) return true } @@ -363,7 +373,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n strId := string(id) if id == nil { - strId = "implicit" + strId = "" } msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId) diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go new file mode 100644 index 000000000..19ddceda7 --- /dev/null +++ b/client/firewall/uspfilter/tracer_test.go @@ -0,0 +1,440 @@ +package uspfilter + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/require" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) { + t.Logf("Trace results: %v", trace.Results) + actualStages := make([]PacketStage, 0, len(trace.Results)) + for _, result := range trace.Results { + actualStages = append(actualStages, result.Stage) + t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed) + } + + require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages") +} + +func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) { + require.NotEmpty(t, trace.Results, "Trace should have results") + lastResult := trace.Results[len(trace.Results)-1] + require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'") + require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect") +} + +func TestTracePacket(t *testing.T) { + setupTracerTest := func(statefulMode bool) *Manager { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.10.0.100"), + Network: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + } + }, + } + + m, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + + if !statefulMode { + m.stateful = false + } + + return m + } + + createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder { + builder := &PacketBuilder{ + SrcIP: netip.MustParseAddr(srcIP), + DstIP: netip.MustParseAddr(dstIP), + Protocol: protocol, + SrcPort: srcPort, + DstPort: dstPort, + Direction: direction, + } + + if protocol == "tcp" { + builder.TCPState = &TCPState{SYN: true} + } + + return builder + } + + createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder { + return &PacketBuilder{ + SrcIP: netip.MustParseAddr(srcIP), + DstIP: netip.MustParseAddr(dstIP), + Protocol: "icmp", + ICMPType: icmpType, + ICMPCode: icmpCode, + Direction: direction, + } + } + + testCases := []struct { + name string + setup func(*Manager) + packetBuilder func() *PacketBuilder + expectedStages []PacketStage + expectedAllow bool + }{ + { + name: "LocalTraffic_ACLAllowed", + setup: func(m *Manager) { + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolTCP + port := &fw.Port{Values: []uint16{80}} + action := fw.ActionAccept + _, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "LocalTraffic_ACLDenied", + setup: func(m *Manager) { + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolTCP + port := &fw.Port{Values: []uint16{80}} + action := fw.ActionDrop + _, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: false, + }, + { + name: "LocalTraffic_WithForwarder", + setup: func(m *Manager) { + m.netstack = true + m.localForwarding = true + + m.forwarder.Store(&forwarder.Forwarder{}) + + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolTCP + port := &fw.Port{Values: []uint16{80}} + action := fw.ActionAccept + _, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageForwarding, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "LocalTraffic_WithoutForwarder", + setup: func(m *Manager) { + m.netstack = true + m.localForwarding = false + + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolTCP + port := &fw.Port{Values: []uint16{80}} + action := fw.ActionAccept + _, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "RoutedTraffic_ACLAllowed", + setup: func(m *Manager) { + m.routingEnabled.Store(true) + m.nativeRouter.Store(false) + + m.forwarder.Store(&forwarder.Forwarder{}) + + src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StageRouteACL, + StageForwarding, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "RoutedTraffic_ACLDenied", + setup: func(m *Manager) { + m.routingEnabled.Store(true) + m.nativeRouter.Store(false) + + src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StageRouteACL, + StageCompleted, + }, + expectedAllow: false, + }, + { + name: "RoutedTraffic_NativeRouter", + setup: func(m *Manager) { + m.routingEnabled.Store(true) + m.nativeRouter.Store(true) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StageRouteACL, + StageForwarding, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "RoutedTraffic_RoutingDisabled", + setup: func(m *Manager) { + m.routingEnabled.Store(false) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StageCompleted, + }, + expectedAllow: false, + }, + { + name: "ConnectionTracking_Hit", + setup: func(m *Manager) { + srcIP := netip.MustParseAddr("100.10.0.100") + dstIP := netip.MustParseAddr("1.1.1.1") + srcPort := uint16(12345) + dstPort := uint16(80) + + m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn) + }, + packetBuilder: func() *PacketBuilder { + pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN) + pb.TCPState = &TCPState{SYN: true, ACK: true} + return pb + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "OutboundTraffic", + setup: func(m *Manager) { + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT) + }, + expectedStages: []PacketStage{ + StageReceived, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "ICMPEchoRequest", + setup: func(m *Manager) { + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolICMP + action := fw.ActionAccept + _, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "ICMPDestinationUnreachable", + setup: func(m *Manager) { + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolICMP + action := fw.ActionDrop + _, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "UDPTraffic_WithoutHook", + setup: func(m *Manager) { + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolUDP + port := &fw.Port{Values: []uint16{53}} + action := fw.ActionAccept + _, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: true, + }, + { + name: "UDPTraffic_WithHook", + setup: func(m *Manager) { + hookFunc := func([]byte) bool { + return true + } + m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageConntrack, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: false, + }, + { + name: "StatefulDisabled_NoTracking", + setup: func(m *Manager) { + m.stateful = false + + ip := net.ParseIP("1.1.1.1") + proto := fw.ProtocolTCP + port := &fw.Port{Values: []uint16{80}} + action := fw.ActionDrop + _, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") + require.NoError(t, err) + }, + packetBuilder: func() *PacketBuilder { + return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN) + }, + expectedStages: []PacketStage{ + StageReceived, + StageRouting, + StagePeerACL, + StageCompleted, + }, + expectedAllow: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := setupTracerTest(true) + + tc.setup(m) + + require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")), + "100.10.0.100 should be recognized as a local IP") + require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")), + "172.17.0.2 should not be recognized as a local IP") + + pb := tc.packetBuilder() + + trace, err := m.TracePacketFromBuilder(pb) + require.NoError(t, err) + + verifyTraceStages(t, trace, tc.expectedStages) + verifyFinalDisposition(t, trace, tc.expectedAllow) + }) + } +}