diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 466065d31..466c6a18b 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -658,7 +658,8 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if !m.isValidPacket(d, packetData) { + valid, fragment := m.isValidPacket(d, packetData) + if !valid { return true } @@ -668,6 +669,13 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool { return true } + // TODO: pass fragments of routed packets to forwarder + if fragment { + m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v", + srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + return false + } + // For all inbound traffic, first check if it matches a tracked connection. // This must happen before any other filtering because the packets are statefully tracked. if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { @@ -815,17 +823,32 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { } } -func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { +// isValidPacket checks if the packet is valid. +// It returns true, false if the packet is valid and not a fragment. +// It returns true, true if the packet is a fragment and valid. +func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { m.logger.Trace("couldn't decode packet, err: %s", err) - return false + return false, false } - if len(d.decoded) < 2 { - m.logger.Trace("packet doesn't have network and transport layers") - return false + l := len(d.decoded) + + // L3 and L4 are mandatory + if l >= 2 { + return true, false } - return true + + // Fragments are also valid + if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 { + ip4 := d.ip4 + if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 { + return true, true + } + } + + m.logger.Trace("packet doesn't have network and transport layers") + return false, false } func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {