Support local IPs in netstack mode

This commit is contained in:
Viktor Liu 2025-01-02 17:53:07 +01:00
parent 2b8092dfad
commit 911f86ded8
5 changed files with 50 additions and 16 deletions

View File

@ -3,6 +3,7 @@ package forwarder
import (
"context"
"fmt"
"net"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
@ -30,9 +31,11 @@ type Forwarder struct {
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@ -101,6 +104,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
udpForwarder: newUDPForwarder(logger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
}
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
@ -142,3 +147,10 @@ func (f *Forwarder) Stop() {
f.stack.Close()
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
}

View File

@ -27,7 +27,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
}
}()
dstIP := net.IP(id.LocalAddress.AsSlice())
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)

View File

@ -17,10 +17,9 @@ import (
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
dstAddr := id.LocalAddress
dstPort := id.LocalPort
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
f.logger.Trace("forwarder: handling TCP connection %v", id)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)

View File

@ -125,14 +125,13 @@ func (f *udpForwarder) cleanup() {
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
id := r.ID()
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
}
id := r.ID()
f.udpForwarder.RLock()
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
@ -141,6 +140,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
return
}
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)

View File

@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@ -57,6 +58,8 @@ type Manager struct {
nativeRouter bool
// indicates whether we track outbound connections
stateful bool
// indicates whether wireguards runs in netstack mode
netstack bool
localipmanager *localIPManager
@ -130,7 +133,8 @@ func create(iface common.IFaceMapper) (*Manager, error) {
localipmanager: newLocalIPManager(),
stateful: !disableConntrack,
// TODO: support changing log level from logrus
logger: nblog.NewFromLogrus(log.StandardLogger()),
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
@ -157,7 +161,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
// Only supported in userspace mode as we need to inject packets back into wireguard directly
} else {
var err error
m.forwarder, err = forwarder.New(iface, m.logger)
m.forwarder, err = forwarder.New(iface, m.logger, m.netstack)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
} else {
@ -505,16 +509,36 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// Handle local traffic - apply peer ACLs
if m.localipmanager.IsLocalIP(dstIP) {
drop := m.applyRules(srcIP, packetData, rules, d)
if drop {
if m.peerACLsBlock(srcIP, packetData, rules, d) {
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
srcIP, dstIP)
return true
}
return drop
// if running in netstack mode we need to pass this to the forwarder
if m.netstack {
m.logger.Trace("Passing local packet to netstack: src=%s dst=%s", srcIP, dstIP)
m.handleNetstackLocalTraffic(packetData)
// don't process this packet further
return true
}
return false
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
}
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
if m.forwarder == nil {
return
}
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
}
}
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
// Drop if routing is disabled
if !m.routingEnabled {
@ -540,8 +564,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
}
// Let forwarder handle the packet if it passed route ACLs
err := m.forwarder.InjectIncomingPacket(packetData)
if err != nil {
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err)
}
@ -631,7 +654,7 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
if m.isSpecialICMP(d) {
return false
}