mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
Support local IPs in netstack mode
This commit is contained in:
parent
2b8092dfad
commit
911f86ded8
@ -3,6 +3,7 @@ package forwarder
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@ -30,9 +31,11 @@ type Forwarder struct {
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@ -101,6 +104,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
|
||||
udpForwarder: newUDPForwarder(logger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
}
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||
@ -142,3 +147,10 @@ func (f *Forwarder) Stop() {
|
||||
f.stack.Close()
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
}
|
||||
}()
|
||||
|
||||
dstIP := net.IP(id.LocalAddress.AsSlice())
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
|
@ -17,10 +17,9 @@ import (
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
dstAddr := id.LocalAddress
|
||||
dstPort := id.LocalPort
|
||||
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
f.logger.Trace("forwarder: handling TCP connection %v", id)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
|
@ -125,14 +125,13 @@ func (f *udpForwarder) cleanup() {
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
||||
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
|
||||
f.udpForwarder.RLock()
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
@ -141,6 +140,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@ -57,6 +58,8 @@ type Manager struct {
|
||||
nativeRouter bool
|
||||
// indicates whether we track outbound connections
|
||||
stateful bool
|
||||
// indicates whether wireguards runs in netstack mode
|
||||
netstack bool
|
||||
|
||||
localipmanager *localIPManager
|
||||
|
||||
@ -130,7 +133,8 @@ func create(iface common.IFaceMapper) (*Manager, error) {
|
||||
localipmanager: newLocalIPManager(),
|
||||
stateful: !disableConntrack,
|
||||
// TODO: support changing log level from logrus
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||
netstack: netstack.IsEnabled(),
|
||||
}
|
||||
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
@ -157,7 +161,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
} else {
|
||||
var err error
|
||||
m.forwarder, err = forwarder.New(iface, m.logger)
|
||||
m.forwarder, err = forwarder.New(iface, m.logger, m.netstack)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create forwarder: %v", err)
|
||||
} else {
|
||||
@ -505,16 +509,36 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
|
||||
// Handle local traffic - apply peer ACLs
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
drop := m.applyRules(srcIP, packetData, rules, d)
|
||||
if drop {
|
||||
if m.peerACLsBlock(srcIP, packetData, rules, d) {
|
||||
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
return drop
|
||||
|
||||
// if running in netstack mode we need to pass this to the forwarder
|
||||
if m.netstack {
|
||||
m.logger.Trace("Passing local packet to netstack: src=%s dst=%s", srcIP, dstIP)
|
||||
m.handleNetstackLocalTraffic(packetData)
|
||||
// don't process this packet further
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
}
|
||||
|
||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
|
||||
if m.forwarder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject local packet: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled {
|
||||
@ -540,8 +564,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
err := m.forwarder.InjectIncomingPacket(packetData)
|
||||
if err != nil {
|
||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
@ -631,7 +654,7 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||
}
|
||||
|
||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||
if m.isSpecialICMP(d) {
|
||||
return false
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user