diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index b3430c085..ffd40d098 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -3,6 +3,7 @@ package forwarder import ( "context" "fmt" + "net" log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/buffer" @@ -30,9 +31,11 @@ type Forwarder struct { udpForwarder *udpForwarder ctx context.Context cancel context.CancelFunc + ip net.IP + netstack bool } -func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) { +func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -101,6 +104,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger) (*Forwarder, error) { udpForwarder: newUDPForwarder(logger), ctx: ctx, cancel: cancel, + netstack: netstack, + ip: iface.Address().IP, } tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP) @@ -142,3 +147,10 @@ func (f *Forwarder) Stop() { f.stack.Close() f.stack.Wait() } + +func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { + if f.netstack && f.ip.Equal(addr.AsSlice()) { + return net.IPv4(127, 0, 0, 1) + } + return addr.AsSlice() +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index c9fede724..10019f21f 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -27,7 +27,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf } }() - dstIP := net.IP(id.LocalAddress.AsSlice()) + dstIP := f.determineDialAddr(id.LocalAddress) dst := &net.IPAddr{IP: dstIP} // Get the complete ICMP message (header + data) diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index bf5320fe1..efe94bae9 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -17,10 +17,9 @@ import ( func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { id := r.ID() - dstAddr := id.LocalAddress - dstPort := id.LocalPort - dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort) + dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + f.logger.Trace("forwarder: handling TCP connection %v", id) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { r.Complete(true) diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 85094baad..a5cba9cb4 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -125,14 +125,13 @@ func (f *udpForwarder) cleanup() { // handleUDP is called by the UDP forwarder for new packets func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { - id := r.ID() - dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort) - if f.ctx.Err() != nil { f.logger.Trace("forwarder: context done, dropping UDP packet") return } + id := r.ID() + f.udpForwarder.RLock() _, exists := f.udpForwarder.conns[id] f.udpForwarder.RUnlock() @@ -141,6 +140,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { return } + dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 93472210c..11ef68a4d 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -57,6 +58,8 @@ type Manager struct { nativeRouter bool // indicates whether we track outbound connections stateful bool + // indicates whether wireguards runs in netstack mode + netstack bool localipmanager *localIPManager @@ -130,7 +133,8 @@ func create(iface common.IFaceMapper) (*Manager, error) { localipmanager: newLocalIPManager(), stateful: !disableConntrack, // TODO: support changing log level from logrus - logger: nblog.NewFromLogrus(log.StandardLogger()), + logger: nblog.NewFromLogrus(log.StandardLogger()), + netstack: netstack.IsEnabled(), } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { @@ -157,7 +161,7 @@ func create(iface common.IFaceMapper) (*Manager, error) { // Only supported in userspace mode as we need to inject packets back into wireguard directly } else { var err error - m.forwarder, err = forwarder.New(iface, m.logger) + m.forwarder, err = forwarder.New(iface, m.logger, m.netstack) if err != nil { log.Errorf("failed to create forwarder: %v", err) } else { @@ -505,16 +509,36 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { // Handle local traffic - apply peer ACLs if m.localipmanager.IsLocalIP(dstIP) { - drop := m.applyRules(srcIP, packetData, rules, d) - if drop { + if m.peerACLsBlock(srcIP, packetData, rules, d) { m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied", srcIP, dstIP) + return true } - return drop + + // if running in netstack mode we need to pass this to the forwarder + if m.netstack { + m.logger.Trace("Passing local packet to netstack: src=%s dst=%s", srcIP, dstIP) + m.handleNetstackLocalTraffic(packetData) + // don't process this packet further + return true + } + + return false } + return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) } +func (m *Manager) handleNetstackLocalTraffic(packetData []byte) { + if m.forwarder == nil { + return + } + + if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { + m.logger.Error("Failed to inject local packet: %v", err) + } +} + func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { // Drop if routing is disabled if !m.routingEnabled { @@ -540,8 +564,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat } // Let forwarder handle the packet if it passed route ACLs - err := m.forwarder.InjectIncomingPacket(packetData) - if err != nil { + if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject incoming packet: %v", err) } @@ -631,7 +654,7 @@ func (m *Manager) isSpecialICMP(d *decoder) bool { icmpType == layers.ICMPv4TypeTimeExceeded } -func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { +func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { if m.isSpecialICMP(d) { return false }