Track icmp with id only (#3447)

This commit is contained in:
Viktor Liu 2025-03-06 14:51:23 +01:00 committed by GitHub
parent 0a042ac36d
commit b180edbe5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 29 deletions

View File

@ -24,12 +24,11 @@ const (
type ICMPConnKey struct { type ICMPConnKey struct {
SrcIP netip.Addr SrcIP netip.Addr
DstIP netip.Addr DstIP netip.Addr
Sequence uint16
ID uint16 ID uint16
} }
func (i ICMPConnKey) String() string { func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.ID, i.Sequence) return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
@ -69,12 +68,11 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
return tracker return tracker
} }
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16) (ICMPConnKey, bool) { func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16) (ICMPConnKey, bool) {
key := ICMPConnKey{ key := ICMPConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
ID: id, ID: id,
Sequence: seq,
} }
t.mutex.RLock() t.mutex.RLock()
@ -91,22 +89,21 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress) t.track(srcIP, dstIP, id, typecode, nftypes.Egress)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) {
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress) t.track(srcIP, dstIP, id, typecode, nftypes.Ingress)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
// TODO: icmp doesn't need to extend the timeout key, exists := t.updateIfExists(srcIP, dstIP, id)
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
if exists { if exists {
return return
} }
@ -141,7 +138,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq u
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, icmpType uint8) bool { func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) { if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false return false
} }
@ -150,7 +147,6 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint
SrcIP: dstIP, SrcIP: dstIP,
DstIP: srcIP, DstIP: srcIP,
ID: id, ID: id,
Sequence: seq,
} }
t.mutex.RLock() t.mutex.RLock()

View File

@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0)
} }
}) })
@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0)
} }
}) })
} }

View File

@ -602,7 +602,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) {
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode)
} }
} }
@ -615,7 +615,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) {
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode)
} }
} }
@ -826,7 +826,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr)
srcIP, srcIP,
dstIP, dstIP,
d.icmp4.Id, d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(), d.icmp4.TypeCode.Type(),
) )