[client] add byte counters & ruleID for routed traffic on userspace (#3653)

* [client] add byte counters for routed traffic on userspace 
* [client] add allowed ruleID for routed traffic on userspace
Author: hakansa, 2025-04-28 10:10:41 +03:00 (committed via GitHub)
Parent: 3cf87b6846
Commit: 84bfecdd37
5 changed files with 235 additions and 107 deletions
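
How the two changes fit together, roughly: when userspace routing lets a packet through, the forwarder now remembers which ACL rule allowed the 4-tuple and attaches it to the flow's TypeStart event, while the proxy loops report their copied byte totals on the TypeEnd event. Below is a small, self-contained sketch of the lookup scheme only; the simplified connKey type stands in for conntrack.ConnKey, and the addresses, ports, and rule bytes are made up, so this is illustrative rather than the actual netbird code.

// Illustrative sketch only: a simplified stand-in for the forwarder's
// ruleIdMap, with connKey playing the role of conntrack.ConnKey.
package main

import (
	"fmt"
	"net/netip"
	"sync"
)

type connKey struct {
	srcIP, dstIP     netip.Addr
	srcPort, dstPort uint16
}

func main() {
	var ruleIDs sync.Map

	client := netip.MustParseAddr("100.64.0.1")
	server := netip.MustParseAddr("10.0.0.2")

	// When the ACL allows the flow, remember the rule under the original direction.
	ruleIDs.LoadOrStore(connKey{client, server, 54321, 443}, []byte{0xaa})

	// Reply packets carry the reversed tuple, so the lookup tries both orders,
	// mirroring getRuleID in the diff below.
	lookup := func(k connKey) ([]byte, bool) {
		if v, ok := ruleIDs.Load(k); ok {
			return v.([]byte), true
		}
		if v, ok := ruleIDs.Load(connKey{k.dstIP, k.srcIP, k.dstPort, k.srcPort}); ok {
			return v.([]byte), true
		}
		return nil, false
	}

	if id, ok := lookup(connKey{server, client, 443, 54321}); ok {
		fmt.Printf("reply matched rule %x\n", id)
	}
}

The byte counters are the other half: the TCP and UDP proxy loops now return their io.Copy totals, and those totals (plus endpoint packet stats) populate RxBytes/TxBytes and RxPackets/TxPackets on the TypeEnd flow event, as the diffs below show.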

View File

@@ -4,7 +4,9 @@ import (
 	"context"
 	"fmt"
 	"net"
+	"net/netip"
 	"runtime"
+	"sync"
 
 	log "github.com/sirupsen/logrus"
 	"gvisor.dev/gvisor/pkg/buffer"
@@ -17,6 +19,7 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
 
 	"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
+	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
 	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )
@@ -29,8 +32,10 @@
 )
 
 type Forwarder struct {
 	logger       *nblog.Logger
 	flowLogger   nftypes.FlowLogger
+	// ruleIdMap is used to store the rule ID for a given connection
+	ruleIdMap    sync.Map
 	stack        *stack.Stack
 	endpoint     *endpoint
 	udpForwarder *udpForwarder
@@ -167,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
 	}
 	return addr.AsSlice()
 }
+
+func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
+	key := buildKey(srcIP, dstIP, srcPort, dstPort)
+	f.ruleIdMap.LoadOrStore(key, ruleID)
+}
+
+func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
+	if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
+		return value.([]byte), true
+	} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
+		return value.([]byte), true
+	}
+
+	return nil, false
+}
+
+func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
+	if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
+		return
+	}
+
+	f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
+}
+
+func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
+	return conntrack.ConnKey{
+		SrcIP:   srcIP,
+		DstIP:   dstIP,
+		SrcPort: srcPort,
+		DstPort: dstPort,
+	}
+}

View File

@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
 	}
 
 	flowID := uuid.New()
-	f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
+	f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
 
 	ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
 	defer cancel()
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
 	// TODO: support non-root
 	conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
 	if err != nil {
-		f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
+		f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
 
 		// This will make netstack reply on behalf of the original destination, that's ok for now
 		return false
 	}
 	defer func() {
 		if err := conn.Close(); err != nil {
-			f.logger.Debug("Failed to close ICMP socket: %v", err)
+			f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
 		}
 	}()
@@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
 	payload := fullPacket.AsSlice()
 
 	if _, err = conn.WriteTo(payload, dst); err != nil {
-		f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
+		f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
 		return true
 	}
 
-	f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
+	f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
 		epID(id), icmpHdr.Type(), icmpHdr.Code())
 
 	// For Echo Requests, send and handle response
 	if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
-		f.handleEchoResponse(icmpHdr, conn, id)
-		f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
+		rxBytes := pkt.Size()
+		txBytes := f.handleEchoResponse(icmpHdr, conn, id)
+		f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
 	}
 
 	// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
 	return true
 }
 
-func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
+func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
 	if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
-		f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
-		return
+		f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
+		return 0
 	}
 
 	response := make([]byte, f.endpoint.mtu)
 	n, _, err := conn.ReadFrom(response)
 	if err != nil {
 		if !isTimeout(err) {
-			f.logger.Error("Failed to read ICMP response: %v", err)
+			f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
 		}
-		return
+		return 0
 	}
 
 	ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
 	fullPacket = append(fullPacket, response[:n]...)
 
 	if err := f.InjectIncomingPacket(fullPacket); err != nil {
-		f.logger.Error("Failed to inject ICMP response: %v", err)
-		return
+		f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
+		return 0
 	}
 
-	f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
+	f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
 		epID(id), icmpHdr.Type(), icmpHdr.Code())
+
+	return len(fullPacket)
 }
 
 // sendICMPEvent stores flow events for ICMP packets
-func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
-	f.flowLogger.StoreEvent(nftypes.EventFields{
+func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
+	var rxPackets, txPackets uint64
+	if rxBytes > 0 {
+		rxPackets = 1
+	}
+	if txBytes > 0 {
+		txPackets = 1
+	}
+
+	srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
+	dstIp := netip.AddrFrom4(id.LocalAddress.As4())
+
+	fields := nftypes.EventFields{
 		FlowID:    flowID,
 		Type:      typ,
 		Direction: nftypes.Ingress,
 		Protocol:  nftypes.ICMP,
 		// TODO: handle ipv6
-		SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
-		DestIP:   netip.AddrFrom4(id.LocalAddress.As4()),
-		ICMPType: icmpType,
-		ICMPCode: icmpCode,
-		// TODO: get packets/bytes
-	})
+		SourceIP:  srcIp,
+		DestIP:    dstIp,
+		ICMPType:  icmpType,
+		ICMPCode:  icmpCode,
+		RxBytes:   rxBytes,
+		TxBytes:   txBytes,
+		RxPackets: rxPackets,
+		TxPackets: txPackets,
+	}
+
+	if typ == nftypes.TypeStart {
+		if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
+			fields.RuleID = ruleId
+		}
+	} else {
+		f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
+	}
+
+	f.flowLogger.StoreEvent(fields)
 }

View File

@@ -6,8 +6,10 @@ import (
 	"io"
 	"net"
 	"net/netip"
+	"sync"
 
 	"github.com/google/uuid"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
 	flowID := uuid.New()
 
-	f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
+	f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
 	var success bool
 	defer func() {
 		if !success {
-			f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
+			f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
 		}
 	}()
@@ -65,67 +67,97 @@
 }
 
 func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
-	defer func() {
-		if err := inConn.Close(); err != nil {
-			f.logger.Debug("forwarder: inConn close error: %v", err)
-		}
-		if err := outConn.Close(); err != nil {
-			f.logger.Debug("forwarder: outConn close error: %v", err)
-		}
-		ep.Close()
-
-		f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
-	}()
-
-	// Create context for managing the proxy goroutines
 	ctx, cancel := context.WithCancel(f.ctx)
 	defer cancel()
 
-	errChan := make(chan error, 2)
-
 	go func() {
-		_, err := io.Copy(outConn, inConn)
-		errChan <- err
+		<-ctx.Done()
+		// Close connections and endpoint.
+		if err := inConn.Close(); err != nil && !isClosedError(err) {
+			f.logger.Debug("forwarder: inConn close error: %v", err)
+		}
+		if err := outConn.Close(); err != nil && !isClosedError(err) {
+			f.logger.Debug("forwarder: outConn close error: %v", err)
+		}
+
+		ep.Close()
 	}()
 
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	var (
+		bytesFromInToOut int64 // bytes from client to server (tx for client)
+		bytesFromOutToIn int64 // bytes from server to client (rx for client)
+		errInToOut       error
+		errOutToIn       error
+	)
+
 	go func() {
-		_, err := io.Copy(inConn, outConn)
-		errChan <- err
+		bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
+		cancel()
+		wg.Done()
 	}()
 
-	select {
-	case <-ctx.Done():
-		f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
-		return
-	case err := <-errChan:
-		if err != nil && !isClosedError(err) {
-			f.logger.Error("proxyTCP: copy error: %v", err)
+	go func() {
+		bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
+		cancel()
+		wg.Done()
+	}()
+
+	wg.Wait()
+
+	if errInToOut != nil {
+		if !isClosedError(errInToOut) {
+			f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
 		}
-		f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
-		return
 	}
+	if errOutToIn != nil {
+		if !isClosedError(errOutToIn) {
+			f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
+		}
+	}
+
+	var rxPackets, txPackets uint64
+	if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
+		// fields are flipped since this is the in conn
+		rxPackets = tcpStats.SegmentsSent.Value()
+		txPackets = tcpStats.SegmentsReceived.Value()
+	}
+
+	f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
+
+	f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
 }
 
-func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
+func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
+	srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
+	dstIp := netip.AddrFrom4(id.LocalAddress.As4())
+
 	fields := nftypes.EventFields{
 		FlowID:     flowID,
 		Type:       typ,
 		Direction:  nftypes.Ingress,
 		Protocol:   nftypes.TCP,
 		// TODO: handle ipv6
-		SourceIP:   netip.AddrFrom4(id.RemoteAddress.As4()),
-		DestIP:     netip.AddrFrom4(id.LocalAddress.As4()),
+		SourceIP:   srcIp,
+		DestIP:     dstIp,
 		SourcePort: id.RemotePort,
 		DestPort:   id.LocalPort,
+		RxBytes:    rxBytes,
+		TxBytes:    txBytes,
+		RxPackets:  rxPackets,
+		TxPackets:  txPackets,
 	}
 
-	if ep != nil {
-		if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
-			// fields are flipped since this is the in conn
-			// TODO: get bytes
-			fields.RxPackets = tcpStats.SegmentsSent.Value()
-			fields.TxPackets = tcpStats.SegmentsReceived.Value()
+	if typ == nftypes.TypeStart {
+		if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
+			fields.RuleID = ruleId
 		}
+	} else {
+		f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
 	}
 
 	f.flowLogger.StoreEvent(fields)
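
One easy-to-miss detail in the hunk above: the flow event is written from the perspective of the inbound (netstack) connection, so the counters look swapped relative to io.Copy's argument order. A trivial, self-contained illustration with made-up numbers (not netbird code):

// Illustrative only: how the copy totals map onto the event's rx/tx fields.
package main

import "fmt"

func main() {
	bytesFromInToOut := int64(1_200) // io.Copy(outConn, inConn): client -> server, tx for the client
	bytesFromOutToIn := int64(5_800) // io.Copy(inConn, outConn): server -> client, rx for the client

	rxBytes := uint64(bytesFromOutToIn)
	txBytes := uint64(bytesFromInToOut)

	fmt.Printf("RxBytes=%d TxBytes=%d\n", rxBytes, txBytes)
}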

View File

@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 	flowID := uuid.New()
 
-	f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
+	f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
 	var success bool
 	defer func() {
 		if !success {
-			f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
+			f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
 		}
 	}()
@@ -199,7 +199,6 @@
 		if err := outConn.Close(); err != nil {
 			f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
 		}
-
 		return
 	}
 	f.udpForwarder.conns[id] = pConn
@@ -212,68 +211,94 @@
 }
 
 func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
-	defer func() {
+	ctx, cancel := context.WithCancel(f.ctx)
+	defer cancel()
+
+	go func() {
+		<-ctx.Done()
 		pConn.cancel()
-		if err := pConn.conn.Close(); err != nil {
+		if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
 			f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
 		}
-		if err := pConn.outConn.Close(); err != nil {
+		if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
 			f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
 		}
 
 		ep.Close()
-
-		f.udpForwarder.Lock()
-		delete(f.udpForwarder.conns, id)
-		f.udpForwarder.Unlock()
-
-		f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
 	}()
 
-	errChan := make(chan error, 2)
+	var wg sync.WaitGroup
+	wg.Add(2)
 
+	var txBytes, rxBytes int64
+	var outboundErr, inboundErr error
+
+	// outbound->inbound: copy from pConn.conn to pConn.outConn
 	go func() {
-		errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
+		defer wg.Done()
+		txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
 	}()
 
+	// inbound->outbound: copy from pConn.outConn to pConn.conn
 	go func() {
-		errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
+		defer wg.Done()
+		rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
 	}()
 
-	select {
-	case <-ctx.Done():
-		f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
-		return
-	case err := <-errChan:
-		if err != nil && !isClosedError(err) {
-			f.logger.Error("proxyUDP: copy error: %v", err)
-		}
-		f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
-		return
+	wg.Wait()
+
+	if outboundErr != nil && !isClosedError(outboundErr) {
+		f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
 	}
+	if inboundErr != nil && !isClosedError(inboundErr) {
+		f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
+	}
+
+	var rxPackets, txPackets uint64
+	if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
+		// fields are flipped since this is the in conn
+		rxPackets = udpStats.PacketsSent.Value()
+		txPackets = udpStats.PacketsReceived.Value()
+	}
+
+	f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
+
+	f.udpForwarder.Lock()
+	delete(f.udpForwarder.conns, id)
+	f.udpForwarder.Unlock()
+
+	f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
 }
 
 // sendUDPEvent stores flow events for UDP connections
-func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
+func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
+	srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
+	dstIp := netip.AddrFrom4(id.LocalAddress.As4())
+
 	fields := nftypes.EventFields{
 		FlowID:     flowID,
 		Type:       typ,
 		Direction:  nftypes.Ingress,
 		Protocol:   nftypes.UDP,
 		// TODO: handle ipv6
-		SourceIP:   netip.AddrFrom4(id.RemoteAddress.As4()),
-		DestIP:     netip.AddrFrom4(id.LocalAddress.As4()),
+		SourceIP:   srcIp,
+		DestIP:     dstIp,
 		SourcePort: id.RemotePort,
 		DestPort:   id.LocalPort,
+		RxBytes:    rxBytes,
+		TxBytes:    txBytes,
+		RxPackets:  rxPackets,
+		TxPackets:  txPackets,
 	}
 
-	if ep != nil {
-		if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
-			// fields are flipped since this is the in conn
-			// TODO: get bytes
-			fields.RxPackets = tcpStats.PacketsSent.Value()
-			fields.TxPackets = tcpStats.PacketsReceived.Value()
+	if typ == nftypes.TypeStart {
+		if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
+			fields.RuleID = ruleId
 		}
+	} else {
+		f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
 	}
 
 	f.flowLogger.StoreEvent(fields)
@@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
 	return time.Since(lastSeen)
 }
 
-func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
+// copy reads from src and writes to dst.
+func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
 	bufp := bufPool.Get().(*[]byte)
 	defer bufPool.Put(bufp)
 	buffer := *bufp
 
+	var totalBytes int64 = 0
	for {
 		if ctx.Err() != nil {
-			return ctx.Err()
+			return totalBytes, ctx.Err()
 		}
 
 		if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
-			return fmt.Errorf("set read deadline: %w", err)
+			return totalBytes, fmt.Errorf("set read deadline: %w", err)
 		}
 
 		n, err := src.Read(buffer)
@@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
 			if isTimeout(err) {
 				continue
 			}
-			return fmt.Errorf("read from %s: %w", direction, err)
+			return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
 		}
 
-		_, err = dst.Write(buffer[:n])
+		nWritten, err := dst.Write(buffer[:n])
 		if err != nil {
-			return fmt.Errorf("write to %s: %w", direction, err)
+			return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
 		}
+		totalBytes += int64(nWritten)
 
 		c.updateLastSeen()
 	}
 }

View File

@@ -824,7 +824,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
 	proto, pnum := getProtocolFromPacket(d)
 	srcPort, dstPort := getPortsFromPacket(d)
 
-	if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
+	ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
+	if !pass {
 		m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
 			ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
@@ -850,8 +851,11 @@
 	if fwd == nil {
 		m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
 	} else {
+		fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
+
 		if err := fwd.InjectIncomingPacket(packetData); err != nil {
 			m.logger.Error("Failed to inject routed packet: %v", err)
+			fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
 		}
 	}