diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index b6be57ff6..49af28547 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -732,34 +732,10 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr dstAddr := netip.AddrFrom4([4]byte(dstIP.To4())) for _, rule := range m.routeRules { - if !rule.destination.Contains(dstAddr) { - continue + if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { + return rule.action == firewall.ActionAccept } - - sourceMatched := false - for _, src := range rule.sources { - if src.Contains(srcAddr) { - sourceMatched = true - break - } - } - if !sourceMatched { - continue - } - - if rule.proto != firewall.ProtocolALL && rule.proto != proto { - continue - } - - if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { - if !m.portsMatch(rule.srcPort, srcPort) || !m.portsMatch(rule.dstPort, dstPort) { - continue - } - } - - return rule.action == firewall.ActionAccept } - return false } @@ -783,9 +759,10 @@ func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto return false } - // Port matches for TCP/UDP only if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { - return m.portsMatch(rule.srcPort, srcPort) && m.portsMatch(rule.dstPort, dstPort) + if !m.portsMatch(rule.srcPort, srcPort) || !m.portsMatch(rule.dstPort, dstPort) { + return false + } } return true diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index abfa4e54d..93b947148 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -223,15 +223,15 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP return buf.Bytes() } -func setupRoutedManager(t testing.TB, network string) *Manager { - t.Helper() +func setupRoutedManager(tb testing.TB, network string) *Manager { + tb.Helper() - ctrl := gomock.NewController(t) + ctrl := gomock.NewController(tb) dev := mocks.NewMockDevice(ctrl) dev.EXPECT().MTU().Return(1500, nil).AnyTimes() localIP, wgNet, err := net.ParseCIDR(network) - require.NoError(t, err) + require.NoError(tb, err) ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, @@ -250,13 +250,13 @@ func setupRoutedManager(t testing.TB, network string) *Manager { } manager, err := Create(ifaceMock) - require.NoError(t, err) - require.NotNil(t, manager) - require.True(t, manager.routingEnabled) - require.False(t, manager.nativeRouter) + require.NoError(tb, err) + require.NotNil(tb, manager) + require.True(tb, manager.routingEnabled) + require.False(tb, manager.nativeRouter) - t.Cleanup(func() { - require.NoError(t, manager.Reset(nil)) + tb.Cleanup(func() { + require.NoError(tb, manager.Reset(nil)) }) return manager