mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 02:08:40 +02:00
Support local IPs in netstack mode
This commit is contained in:
parent
2b8092dfad
commit
911f86ded8
@ -3,6 +3,7 @@ package forwarder
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
@ -30,9 +31,11 @@ type Forwarder struct {
|
|||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
ip net.IP
|
||||||
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@ -101,6 +104,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) {
|
|||||||
udpForwarder: newUDPForwarder(logger),
|
udpForwarder: newUDPForwarder(logger),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
|
netstack: netstack,
|
||||||
|
ip: iface.Address().IP,
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||||
@ -142,3 +147,10 @@ func (f *Forwarder) Stop() {
|
|||||||
f.stack.Close()
|
f.stack.Close()
|
||||||
f.stack.Wait()
|
f.stack.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
|
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||||
|
return net.IPv4(127, 0, 0, 1)
|
||||||
|
}
|
||||||
|
return addr.AsSlice()
|
||||||
|
}
|
||||||
|
@ -27,7 +27,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dstIP := net.IP(id.LocalAddress.AsSlice())
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
dst := &net.IPAddr{IP: dstIP}
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
// Get the complete ICMP message (header + data)
|
// Get the complete ICMP message (header + data)
|
||||||
|
@ -17,10 +17,9 @@ import (
|
|||||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
id := r.ID()
|
id := r.ID()
|
||||||
|
|
||||||
dstAddr := id.LocalAddress
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
dstPort := id.LocalPort
|
|
||||||
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
|
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: handling TCP connection %v", id)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
|
@ -125,14 +125,13 @@ func (f *udpForwarder) cleanup() {
|
|||||||
|
|
||||||
// handleUDP is called by the UDP forwarder for new packets
|
// handleUDP is called by the UDP forwarder for new packets
|
||||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||||
id := r.ID()
|
|
||||||
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
|
||||||
|
|
||||||
if f.ctx.Err() != nil {
|
if f.ctx.Err() != nil {
|
||||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
f.udpForwarder.RLock()
|
f.udpForwarder.RLock()
|
||||||
_, exists := f.udpForwarder.conns[id]
|
_, exists := f.udpForwarder.conns[id]
|
||||||
f.udpForwarder.RUnlock()
|
f.udpForwarder.RUnlock()
|
||||||
@ -141,6 +140,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -57,6 +58,8 @@ type Manager struct {
|
|||||||
nativeRouter bool
|
nativeRouter bool
|
||||||
// indicates whether we track outbound connections
|
// indicates whether we track outbound connections
|
||||||
stateful bool
|
stateful bool
|
||||||
|
// indicates whether wireguards runs in netstack mode
|
||||||
|
netstack bool
|
||||||
|
|
||||||
localipmanager *localIPManager
|
localipmanager *localIPManager
|
||||||
|
|
||||||
@ -131,6 +134,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
|
|||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
// TODO: support changing log level from logrus
|
// TODO: support changing log level from logrus
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
netstack: netstack.IsEnabled(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
@ -157,7 +161,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
|
|||||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
m.forwarder, err = forwarder.New(iface, m.logger)
|
m.forwarder, err = forwarder.New(iface, m.logger, m.netstack)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create forwarder: %v", err)
|
log.Errorf("failed to create forwarder: %v", err)
|
||||||
} else {
|
} else {
|
||||||
@ -505,16 +509,36 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
|||||||
|
|
||||||
// Handle local traffic - apply peer ACLs
|
// Handle local traffic - apply peer ACLs
|
||||||
if m.localipmanager.IsLocalIP(dstIP) {
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
drop := m.applyRules(srcIP, packetData, rules, d)
|
if m.peerACLsBlock(srcIP, packetData, rules, d) {
|
||||||
if drop {
|
|
||||||
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
|
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
|
||||||
srcIP, dstIP)
|
srcIP, dstIP)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
return drop
|
|
||||||
|
// if running in netstack mode we need to pass this to the forwarder
|
||||||
|
if m.netstack {
|
||||||
|
m.logger.Trace("Passing local packet to netstack: src=%s dst=%s", srcIP, dstIP)
|
||||||
|
m.handleNetstackLocalTraffic(packetData)
|
||||||
|
// don't process this packet further
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) {
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||||
// Drop if routing is disabled
|
// Drop if routing is disabled
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled {
|
||||||
@ -540,8 +564,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Let forwarder handle the packet if it passed route ACLs
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
err := m.forwarder.InjectIncomingPacket(packetData)
|
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
||||||
if err != nil {
|
|
||||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -631,7 +654,7 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
|||||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
||||||
if m.isSpecialICMP(d) {
|
if m.isSpecialICMP(d) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user