Add peer traffic rule IDs to allowed connections in flows (#3442)

This commit is contained in:
Viktor Liu 2025-03-07 13:56:26 +01:00 committed by GitHub
parent 8b07f21c28
commit e8d8bd8f18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 45 additions and 41 deletions

View File

@ -93,18 +93,17 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, size) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, size) t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, size int) { func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
// TODO: icmp doesn't need to extend the timeout
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists { if exists {
return return
@ -115,7 +114,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, size) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@ -136,7 +135,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
@ -186,7 +185,7 @@ func (t *ICMPTracker) cleanup() {
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
@ -201,10 +200,11 @@ func (t *ICMPTracker) Close() {
t.mutex.Unlock() t.mutex.Unlock()
} }
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) { func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP, SourceIP: conn.SourceIP,
@ -218,10 +218,11 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) {
}) })
} }
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, size int) { func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeStart, Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction, Direction: direction,
Protocol: nftypes.ICMP, Protocol: nftypes.ICMP,
SourceIP: srcIP, SourceIP: srcIP,

View File

@ -175,17 +175,17 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
} }
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) { func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists { if exists {
return return
@ -212,7 +212,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
@ -246,7 +246,7 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
conn.UpdateCounters(nftypes.Ingress, size) conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key) t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
return true return true
} }
@ -304,7 +304,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
} else if flags&TCPRst != 0 { } else if flags&TCPRst != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetTombstone() conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
@ -318,7 +318,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
case flags&TCPRst != 0: case flags&TCPRst != 0:
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetTombstone() conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait2: case TCPStateFinWait2:
@ -326,7 +326,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key) t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateClosing: case TCPStateClosing:
@ -335,7 +335,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key) t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@ -349,7 +349,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.SetTombstone() conn.SetTombstone()
// Send close event for gracefully closed connections // Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key) t.logger.Trace("TCP connection %s closed gracefully", key)
} }
} }
@ -436,7 +436,7 @@ func (t *TCPTracker) cleanup() {
// event already handled by state change // event already handled by state change
if conn.State != TCPStateTimeWait { if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
@ -467,10 +467,11 @@ func isValidFlagCombination(flags uint8) bool {
return true return true
} }
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) { func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.TCP, Protocol: nftypes.TCP,
SourceIP: conn.SourceIP, SourceIP: conn.SourceIP,

View File

@ -59,13 +59,13 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, size) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
} }
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, size) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
} }
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
@ -90,7 +90,7 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists { if exists {
return return
@ -113,7 +113,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key) t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // IsValidInbound checks if an inbound packet matches a tracked connection
@ -161,7 +161,7 @@ func (t *UDPTracker) cleanup() {
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
@ -196,10 +196,11 @@ func (t *UDPTracker) Timeout() time.Duration {
return t.timeout return t.timeout
} }
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) { func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.UDP, Protocol: nftypes.UDP,
SourceIP: conn.SourceIP, SourceIP: conn.SourceIP,

View File

@ -604,16 +604,16 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
} }
} }
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, size int) { func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
} }
} }
@ -684,17 +684,18 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked { ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
if blocked {
_, pnum := getProtocolFromPacket(d) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleId, pnum, srcIP, srcPort, dstIP, dstPort) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeDrop, Type: nftypes.TypeDrop,
RuleID: ruleId, RuleID: ruleID,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: pnum,
SourceIP: srcIP, SourceIP: srcIP,
@ -714,7 +715,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP, size) m.trackInbound(d, srcIP, dstIP, ruleID, size)
return false return false
} }
@ -756,14 +757,14 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
proto, pnum := getProtocolFromPacket(d) proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
id, pnum, srcIP, srcPort, dstIP, dstPort) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeDrop, Type: nftypes.TypeDrop,
RuleID: id, RuleID: ruleID,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: pnum,
SourceIP: srcIP, SourceIP: srcIP,