Support local IPs in netstack mode

This commit is contained in:
Viktor Liu 2025-01-02 17:53:07 +01:00
parent 2b8092dfad
commit 911f86ded8
5 changed files with 50 additions and 16 deletions

View File

@ -3,6 +3,7 @@ package forwarder
import ( import (
"context" "context"
"fmt" "fmt"
"net"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@ -30,9 +31,11 @@ type Forwarder struct {
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip net.IP
netstack bool
} }
func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@ -101,6 +104,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
udpForwarder: newUDPForwarder(logger), udpForwarder: newUDPForwarder(logger),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
} }
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP) tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
@ -142,3 +147,10 @@ func (f *Forwarder) Stop() {
f.stack.Close() f.stack.Close()
f.stack.Wait() f.stack.Wait()
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
}

View File

@ -27,7 +27,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
} }
}() }()
dstIP := net.IP(id.LocalAddress.AsSlice()) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data) // Get the complete ICMP message (header + data)

View File

@ -17,10 +17,9 @@ import (
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID() id := r.ID()
dstAddr := id.LocalAddress dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
dstPort := id.LocalPort
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
f.logger.Trace("forwarder: handling TCP connection %v", id)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil { if err != nil {
r.Complete(true) r.Complete(true)

View File

@ -125,14 +125,13 @@ func (f *udpForwarder) cleanup() {
// handleUDP is called by the UDP forwarder for new packets // handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
id := r.ID()
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
if f.ctx.Err() != nil { if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet") f.logger.Trace("forwarder: context done, dropping UDP packet")
return return
} }
id := r.ID()
f.udpForwarder.RLock() f.udpForwarder.RLock()
_, exists := f.udpForwarder.conns[id] _, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock() f.udpForwarder.RUnlock()
@ -141,6 +140,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
return return
} }
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)

View File

@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@ -57,6 +58,8 @@ type Manager struct {
nativeRouter bool nativeRouter bool
// indicates whether we track outbound connections // indicates whether we track outbound connections
stateful bool stateful bool
// indicates whether wireguards runs in netstack mode
netstack bool
localipmanager *localIPManager localipmanager *localIPManager
@ -131,6 +134,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
stateful: !disableConntrack, stateful: !disableConntrack,
// TODO: support changing log level from logrus // TODO: support changing log level from logrus
logger: nblog.NewFromLogrus(log.StandardLogger()), logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
} }
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
@ -157,7 +161,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
// Only supported in userspace mode as we need to inject packets back into wireguard directly // Only supported in userspace mode as we need to inject packets back into wireguard directly
} else { } else {
var err error var err error
m.forwarder, err = forwarder.New(iface, m.logger) m.forwarder, err = forwarder.New(iface, m.logger, m.netstack)
if err != nil { if err != nil {
log.Errorf("failed to create forwarder: %v", err) log.Errorf("failed to create forwarder: %v", err)
} else { } else {
@ -505,16 +509,36 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// Handle local traffic - apply peer ACLs // Handle local traffic - apply peer ACLs
if m.localipmanager.IsLocalIP(dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
drop := m.applyRules(srcIP, packetData, rules, d) if m.peerACLsBlock(srcIP, packetData, rules, d) {
if drop {
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied", m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
srcIP, dstIP) srcIP, dstIP)
return true
} }
return drop
// if running in netstack mode we need to pass this to the forwarder
if m.netstack {
m.logger.Trace("Passing local packet to netstack: src=%s dst=%s", srcIP, dstIP)
m.handleNetstackLocalTraffic(packetData)
// don't process this packet further
return true
} }
return false
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
} }
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
if m.forwarder == nil {
return
}
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
}
}
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled { if !m.routingEnabled {
@ -540,8 +564,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
} }
// Let forwarder handle the packet if it passed route ACLs // Let forwarder handle the packet if it passed route ACLs
err := m.forwarder.InjectIncomingPacket(packetData) if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
if err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err) m.logger.Error("Failed to inject incoming packet: %v", err)
} }
@ -631,7 +654,7 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
if m.isSpecialICMP(d) { if m.isSpecialICMP(d) {
return false return false
} }