Reduce complexity and fix linter issues

This commit is contained in:
Viktor Liu 2025-01-03 15:43:28 +01:00
parent c68be6b61b
commit 979fe6bb6a
2 changed files with 15 additions and 38 deletions

View File

@ -732,34 +732,10 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4())) dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if !rule.destination.Contains(dstAddr) { if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
continue return rule.action == firewall.ActionAccept
} }
sourceMatched := false
for _, src := range rule.sources {
if src.Contains(srcAddr) {
sourceMatched = true
break
}
}
if !sourceMatched {
continue
}
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
continue
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !m.portsMatch(rule.srcPort, srcPort) || !m.portsMatch(rule.dstPort, dstPort) {
continue
}
}
return rule.action == firewall.ActionAccept
} }
return false return false
} }
@ -783,9 +759,10 @@ func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto
return false return false
} }
// Port matches for TCP/UDP only
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
return m.portsMatch(rule.srcPort, srcPort) && m.portsMatch(rule.dstPort, dstPort) if !m.portsMatch(rule.srcPort, srcPort) || !m.portsMatch(rule.dstPort, dstPort) {
return false
}
} }
return true return true

View File

@ -223,15 +223,15 @@ func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcP
return buf.Bytes() return buf.Bytes()
} }
func setupRoutedManager(t testing.TB, network string) *Manager { func setupRoutedManager(tb testing.TB, network string) *Manager {
t.Helper() tb.Helper()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(tb)
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
localIP, wgNet, err := net.ParseCIDR(network) localIP, wgNet, err := net.ParseCIDR(network)
require.NoError(t, err) require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
@ -250,13 +250,13 @@ func setupRoutedManager(t testing.TB, network string) *Manager {
} }
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(tb, err)
require.NotNil(t, manager) require.NotNil(tb, manager)
require.True(t, manager.routingEnabled) require.True(tb, manager.routingEnabled)
require.False(t, manager.nativeRouter) require.False(tb, manager.nativeRouter)
t.Cleanup(func() { tb.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(tb, manager.Reset(nil))
}) })
return manager return manager