From e8d8bd8f182a4b0b7d11561ce95c8274598c63ad Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 7 Mar 2025 13:56:26 +0100 Subject: [PATCH] Add peer traffic rule IDs to allowed connections in flows (#3442) --- client/firewall/uspfilter/conntrack/icmp.go | 21 ++++++++-------- client/firewall/uspfilter/conntrack/tcp.go | 27 +++++++++++---------- client/firewall/uspfilter/conntrack/udp.go | 15 ++++++------ client/firewall/uspfilter/uspfilter.go | 23 +++++++++--------- 4 files changed, 45 insertions(+), 41 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 5a736d132..49cc832e6 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -93,18 +93,17 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, id, typecode, nftypes.Egress, size) + t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size) } } // TrackInbound records an inbound ICMP Echo Request -func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { - t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, size) +func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) { + t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size) } // track is the common implementation for tracking both inbound and outbound ICMP connections -func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, size int) { - // TODO: icmp doesn't need to extend the timeout +func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) { key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) if exists { return @@ -115,7 +114,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) - t.sendStartEvent(direction, srcIP, dstIP, typ, code, size) + t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) return } @@ -136,7 +135,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec t.mutex.Unlock() t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) - t.sendEvent(nftypes.TypeStart, conn) + t.sendEvent(nftypes.TypeStart, conn, ruleId) } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request @@ -186,7 +185,7 @@ func (t *ICMPTracker) cleanup() { t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) } } } @@ -201,10 +200,11 @@ func (t *ICMPTracker) Close() { t.mutex.Unlock() } -func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) { +func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) { t.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: conn.FlowId, Type: typ, + RuleID: ruleID, Direction: conn.Direction, Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 SourceIP: conn.SourceIP, @@ -218,10 +218,11 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) { }) } -func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, size int) { +func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) { fields := nftypes.EventFields{ FlowID: uuid.New(), Type: nftypes.TypeStart, + RuleID: ruleID, Direction: direction, Protocol: nftypes.ICMP, SourceIP: srcIP, diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 96bbf0220..b5e470bf9 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -175,17 +175,17 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, size) + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) } } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, size) +func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) { +func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists { return @@ -212,7 +212,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.sendEvent(nftypes.TypeStart, conn) + t.sendEvent(nftypes.TypeStart, conn, ruleID) } // IsValidInbound checks if an inbound TCP packet matches a tracked connection @@ -246,7 +246,7 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort conn.UpdateCounters(nftypes.Ingress, size) t.logger.Trace("TCP connection reset: %s", key) - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) return true } @@ -304,7 +304,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i } else if flags&TCPRst != 0 { conn.State = TCPStateClosed conn.SetTombstone() - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) } case TCPStateFinWait1: @@ -318,7 +318,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i case flags&TCPRst != 0: conn.State = TCPStateClosed conn.SetTombstone() - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) } case TCPStateFinWait2: @@ -326,7 +326,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i conn.State = TCPStateTimeWait t.logger.Trace("TCP connection %s completed", key) - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) } case TCPStateClosing: @@ -335,7 +335,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i // Keep established = false from previous state t.logger.Trace("TCP connection %s closed (simultaneous)", key) - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) } case TCPStateCloseWait: @@ -349,7 +349,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i conn.SetTombstone() // Send close event for gracefully closed connections - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) t.logger.Trace("TCP connection %s closed gracefully", key) } } @@ -436,7 +436,7 @@ func (t *TCPTracker) cleanup() { // event already handled by state change if conn.State != TCPStateTimeWait { - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) } } } @@ -467,10 +467,11 @@ func isValidFlagCombination(flags uint8) bool { return true } -func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) { +func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) { t.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: conn.FlowId, Type: typ, + RuleID: ruleID, Direction: conn.Direction, Protocol: nftypes.TCP, SourceIP: conn.SourceIP, diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 351b3c93a..94db24f5f 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -59,13 +59,13 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, size) + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) } } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) } func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { @@ -90,7 +90,7 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) { +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return @@ -113,7 +113,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.mutex.Unlock() t.logger.Trace("New %s UDP connection: %s", direction, key) - t.sendEvent(nftypes.TypeStart, conn) + t.sendEvent(nftypes.TypeStart, conn, ruleID) } // IsValidInbound checks if an inbound packet matches a tracked connection @@ -161,7 +161,7 @@ func (t *UDPTracker) cleanup() { t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn) + t.sendEvent(nftypes.TypeEnd, conn, nil) } } } @@ -196,10 +196,11 @@ func (t *UDPTracker) Timeout() time.Duration { return t.timeout } -func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) { +func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) { t.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: conn.FlowId, Type: typ, + RuleID: ruleID, Direction: conn.Direction, Protocol: nftypes.UDP, SourceIP: conn.SourceIP, diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 23dc95517..723ef6299 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -604,16 +604,16 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { } } -func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) case layers.LayerTypeICMPv4: - m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) + m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size) } } @@ -684,17 +684,18 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool { // handleLocalTraffic handles local traffic. // If it returns true, the packet should be dropped. func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { - if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked { + ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) + if blocked { _, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleId, pnum, srcIP, srcPort, dstIP, dstPort) + ruleID, pnum, srcIP, srcPort, dstIP, dstPort) m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), Type: nftypes.TypeDrop, - RuleID: ruleId, + RuleID: ruleID, Direction: nftypes.Ingress, Protocol: pnum, SourceIP: srcIP, @@ -714,7 +715,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet } // track inbound packets to get the correct direction and session id for flows - m.trackInbound(d, srcIP, dstIP, size) + m.trackInbound(d, srcIP, dstIP, ruleID, size) return false } @@ -756,14 +757,14 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe proto, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { + if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - id, pnum, srcIP, srcPort, dstIP, dstPort) + ruleID, pnum, srcIP, srcPort, dstIP, dstPort) m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), Type: nftypes.TypeDrop, - RuleID: id, + RuleID: ruleID, Direction: nftypes.Ingress, Protocol: pnum, SourceIP: srcIP,