diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 0dff3acc7..2ae983f6e 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" + "net/netip" "runtime" + "sync" log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/buffer" @@ -17,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) @@ -29,8 +32,10 @@ const ( ) type Forwarder struct { - logger *nblog.Logger - flowLogger nftypes.FlowLogger + logger *nblog.Logger + flowLogger nftypes.FlowLogger + // ruleIdMap is used to store the rule ID for a given connection + ruleIdMap sync.Map stack *stack.Stack endpoint *endpoint udpForwarder *udpForwarder @@ -167,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { } return addr.AsSlice() } + +func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { + key := buildKey(srcIP, dstIP, srcPort, dstPort) + f.ruleIdMap.LoadOrStore(key, ruleID) +} + +func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { + + if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return value.([]byte), true + } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { + return value.([]byte), true + } + + return nil, false +} + +func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { + if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return + } + f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort)) +} + +func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey { + return conntrack.ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index a21ec2c87..08d77ed05 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf } flowID := uuid.New() - f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) + f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) defer cancel() @@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf // TODO: support non-root conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") if err != nil { - f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) // This will make netstack reply on behalf of the original destination, that's ok for now return false } defer func() { if err := conn.Close(); err != nil { - f.logger.Debug("Failed to close ICMP socket: %v", err) + f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err) } }() @@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf payload := fullPacket.AsSlice() if _, err = conn.WriteTo(payload, dst); err != nil { - f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) return true } - f.logger.Trace("Forwarded ICMP packet %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) // For Echo Requests, send and handle response if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { - f.handleEchoResponse(icmpHdr, conn, id) - f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) + rxBytes := pkt.Size() + txBytes := f.handleEchoResponse(icmpHdr, conn, id) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing return true } -func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) { +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - f.logger.Error("Failed to set read deadline for ICMP response: %v", err) - return + f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err) + return 0 } response := make([]byte, f.endpoint.mtu) n, _, err := conn.ReadFrom(response) if err != nil { if !isTimeout(err) { - f.logger.Error("Failed to read ICMP response: %v", err) + f.logger.Error("forwarder: Failed to read ICMP response: %v", err) } - return + return 0 } ipHdr := make([]byte, header.IPv4MinimumSize) @@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon fullPacket = append(fullPacket, response[:n]...) if err := f.InjectIncomingPacket(fullPacket); err != nil { - f.logger.Error("Failed to inject ICMP response: %v", err) + f.logger.Error("forwarder: Failed to inject ICMP response: %v", err) - return + return 0 } - f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) + + return len(fullPacket) } // sendICMPEvent stores flow events for ICMP packets -func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) { - f.flowLogger.StoreEvent(nftypes.EventFields{ +func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) { + var rxPackets, txPackets uint64 + if rxBytes > 0 { + rxPackets = 1 + } + if txBytes > 0 { + txPackets = 1 + } + + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.ICMP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, ICMPType: icmpType, ICMPCode: icmpCode, - // TODO: get packets/bytes - }) + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, + } + + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId + } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) + } + + f.flowLogger.StoreEvent(fields) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 71cd457ef..04b3ae233 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -6,8 +6,10 @@ import ( "io" "net" "net/netip" + "sync" "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { flowID := uuid.New() - f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { - defer func() { - if err := inConn.Close(); err != nil { - f.logger.Debug("forwarder: inConn close error: %v", err) - } - if err := outConn.Close(); err != nil { - f.logger.Debug("forwarder: outConn close error: %v", err) - } - ep.Close() - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep) - }() - - // Create context for managing the proxy goroutines ctx, cancel := context.WithCancel(f.ctx) defer cancel() - errChan := make(chan error, 2) - go func() { - _, err := io.Copy(outConn, inConn) - errChan <- err - }() - - go func() { - _, err := io.Copy(inConn, outConn) - errChan <- err - }() - - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyTCP: copy error: %v", err) + <-ctx.Done() + // Close connections and endpoint. + if err := inConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: inConn close error: %v", err) + } + if err := outConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + + ep.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var ( + bytesFromInToOut int64 // bytes from client to server (tx for client) + bytesFromOutToIn int64 // bytes from server to client (rx for client) + errInToOut error + errOutToIn error + ) + + go func() { + bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) + cancel() + wg.Done() + }() + + go func() { + + bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) + cancel() + wg.Done() + }() + + wg.Wait() + + if errInToOut != nil { + if !isClosedError(errInToOut) { + f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut) } - f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id)) - return } + if errOutToIn != nil { + if !isClosedError(errOutToIn) { + f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn) + } + } + + var rxPackets, txPackets uint64 + if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { + // fields are flipped since this is the in conn + rxPackets = tcpStats.SegmentsSent.Value() + txPackets = tcpStats.SegmentsReceived.Value() + } + + f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } -func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.TCP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.SegmentsSent.Value() - fields.TxPackets = tcpStats.SegmentsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 7ce85e2b6..cb88aa59a 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { flowID := uuid.New() - f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - return } f.udpForwarder.conns[id] = pConn @@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { } func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { - defer func() { + + ctx, cancel := context.WithCancel(f.ctx) + defer cancel() + + go func() { + <-ctx.Done() + pConn.cancel() - if err := pConn.conn.Close(); err != nil { + if err := pConn.conn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) } - if err := pConn.outConn.Close(); err != nil { + if err := pConn.outConn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } ep.Close() - - f.udpForwarder.Lock() - delete(f.udpForwarder.conns, id) - f.udpForwarder.Unlock() - - f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep) }() - errChan := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) + var txBytes, rxBytes int64 + var outboundErr, inboundErr error + + // outbound->inbound: copy from pConn.conn to pConn.outConn go func() { - errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") + defer wg.Done() + txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") }() + // inbound->outbound: copy from pConn.outConn to pConn.conn go func() { - errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") + defer wg.Done() + rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") }() - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyUDP: copy error: %v", err) - } - f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id)) - return + wg.Wait() + + if outboundErr != nil && !isClosedError(outboundErr) { + f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr) } + if inboundErr != nil && !isClosedError(inboundErr) { + f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr) + } + + var rxPackets, txPackets uint64 + if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { + // fields are flipped since this is the in conn + rxPackets = udpStats.PacketsSent.Value() + txPackets = udpStats.PacketsReceived.Value() + } + + f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + + f.udpForwarder.Lock() + delete(f.udpForwarder.conns, id) + f.udpForwarder.Unlock() + + f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets) } // sendUDPEvent stores flow events for UDP connections -func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.UDP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.PacketsSent.Value() - fields.TxPackets = tcpStats.PacketsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) @@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration { return time.Since(lastSeen) } -func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { +// copy reads from src and writes to dst. +func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) { bufp := bufPool.Get().(*[]byte) defer bufPool.Put(bufp) buffer := *bufp + var totalBytes int64 = 0 for { if ctx.Err() != nil { - return ctx.Err() + return totalBytes, ctx.Err() } if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set read deadline: %w", err) + return totalBytes, fmt.Errorf("set read deadline: %w", err) } n, err := src.Read(buffer) @@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu if isTimeout(err) { continue } - return fmt.Errorf("read from %s: %w", direction, err) + return totalBytes, fmt.Errorf("read from %s: %w", direction, err) } - _, err = dst.Write(buffer[:n]) + nWritten, err := dst.Write(buffer[:n]) if err != nil { - return fmt.Errorf("write to %s: %w", direction, err) + return totalBytes, fmt.Errorf("write to %s: %w", direction, err) } + totalBytes += int64(nWritten) c.updateLastSeen() } } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index ccf0be225..11730dbb3 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -824,7 +824,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe proto, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { + ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + if !pass { m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", ruleID, pnum, srcIP, srcPort, dstIP, dstPort) @@ -850,8 +851,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe if fwd == nil { m.logger.Trace("failed to forward routed packet (forwarder not initialized)") } else { + fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID) + if err := fwd.InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject routed packet: %v", err) + fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort) } }