diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 630f6d04e..fda08154c 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -29,13 +29,15 @@ type ICMPConnKey struct { ID uint16 } -func (i *ICMPConnKey) String() string { - return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.Sequence, i.ID) +func (i ICMPConnKey) String() string { + return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.ID, i.Sequence) } // ICMPConnTrack represents an ICMP connection state type ICMPConnTrack struct { BaseConnTrack + ICMPType uint8 + ICMPCode uint8 } // ICMPTracker manages ICMP connection states @@ -85,25 +87,35 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq } // TrackOutbound records an outbound ICMP connection -func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, id, seq, nftypes.Egress) + t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress) } } // TrackInbound records an inbound ICMP Echo Request -func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { - t.track(srcIP, dstIP, id, seq, nftypes.Ingress) +func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { + t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress) } // track is the common implementation for tracking both inbound and outbound ICMP connections -func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, direction nftypes.Direction) { +func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { + // TODO: icmp doesn't need to extend the timeout key, exists := t.updateIfExists(srcIP, dstIP, id, seq) if exists { return } + typ, code := typecode.Type(), typecode.Code() + + // non echo requests don't need tracking + if typ != uint8(layers.ICMPv4TypeEchoRequest) { + t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) + t.sendStartEvent(direction, key, typ, code) + return + } + conn := &ICMPConnTrack{ BaseConnTrack: BaseConnTrack{ FlowId: uuid.New(), @@ -111,6 +123,8 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, d SourceIP: key.SrcIP, DestIP: key.DstIP, }, + ICMPType: typ, + ICMPCode: code, } conn.UpdateLastSeen() @@ -118,7 +132,7 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace("New %s ICMP connection %s", conn.Direction, key) + t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.sendEvent(nftypes.TypeStart, key, conn) } @@ -186,7 +200,21 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPCon Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 SourceIP: key.SrcIP, DestIP: key.DstIP, - // TODO: add icmp code/type, + ICMPType: conn.ICMPType, + ICMPCode: conn.ICMPCode, + }) +} + +func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) { + t.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: uuid.New(), + Type: nftypes.TypeStart, + Direction: direction, + Protocol: nftypes.ICMP, + SourceIP: key.SrcIP, + DestIP: key.DstIP, + ICMPType: typ, + ICMPCode: code, }) } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index ef5317d41..b8328ae94 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), 0) } }) @@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), 0) } b.ResetTimer() diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index baccab3fb..7485c5267 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -590,9 +590,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) { flags := getTCPFlags(&d.tcp) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) case layers.LayerTypeICMPv4: - if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { - m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq) - } + m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) } } @@ -605,9 +603,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) { flags := getTCPFlags(&d.tcp) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) case layers.LayerTypeICMPv4: - if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { - m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq) - } + m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) } }