diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 97617a52d..7837c1a23 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -22,14 +22,13 @@ const ( // ICMPConnKey uniquely identifies an ICMP connection type ICMPConnKey struct { - SrcIP netip.Addr - DstIP netip.Addr - Sequence uint16 - ID uint16 + SrcIP netip.Addr + DstIP netip.Addr + ID uint16 } func (i ICMPConnKey) String() string { - return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.ID, i.Sequence) + return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID) } // ICMPConnTrack represents an ICMP connection state @@ -69,12 +68,11 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty return tracker } -func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16) (ICMPConnKey, bool) { +func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16) (ICMPConnKey, bool) { key := ICMPConnKey{ - SrcIP: srcIP, - DstIP: dstIP, - ID: id, - Sequence: seq, + SrcIP: srcIP, + DstIP: dstIP, + ID: id, } t.mutex.RLock() @@ -91,22 +89,21 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint } // TrackOutbound records an outbound ICMP connection -func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { - if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists { +func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) { + if _, exists := t.updateIfExists(dstIP, srcIP, id); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress) + t.track(srcIP, dstIP, id, typecode, nftypes.Egress) } } // TrackInbound records an inbound ICMP Echo Request -func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { - t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress) +func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) { + t.track(srcIP, dstIP, id, typecode, nftypes.Ingress) } // track is the common implementation for tracking both inbound and outbound ICMP connections -func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { - // TODO: icmp doesn't need to extend the timeout - key, exists := t.updateIfExists(srcIP, dstIP, id, seq) +func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { + key, exists := t.updateIfExists(srcIP, dstIP, id) if exists { return } @@ -141,16 +138,15 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq u } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request -func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, icmpType uint8) bool { +func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8) bool { if icmpType != uint8(layers.ICMPv4TypeEchoReply) { return false } key := ICMPConnKey{ - SrcIP: dstIP, - DstIP: srcIP, - ID: id, - Sequence: seq, + SrcIP: dstIP, + DstIP: srcIP, + ID: id, } t.mutex.RLock() diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 259cc21a4..01040feb2 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), 0) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0) } }) @@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), 0) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0) } b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0) } }) } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 05dd9ff06..e663966ab 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -602,7 +602,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) { flags := getTCPFlags(&d.tcp) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) case layers.LayerTypeICMPv4: - m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) + m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode) } } @@ -615,7 +615,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) { flags := getTCPFlags(&d.tcp) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) case layers.LayerTypeICMPv4: - m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) + m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode) } } @@ -826,7 +826,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) srcIP, dstIP, d.icmp4.Id, - d.icmp4.Seq, d.icmp4.TypeCode.Type(), )