Add userspace routing

This commit is contained in:
Viktor Liu 2024-12-26 15:07:27 +01:00
parent b3c87cb5d1
commit 4199da4a45
21 changed files with 712 additions and 54 deletions

View File

@ -1,6 +1,8 @@
package firewall package firewall
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@ -10,4 +12,6 @@ type IFaceMapper interface {
Address() device.WGAddress Address() device.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
} }

View File

@ -0,0 +1,16 @@
package common
import (
device2 "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
GetWGDevice() *device2.Device
GetDevice() *device.FilteredDevice
}

View File

@ -0,0 +1,79 @@
package forwarder
import (
log "github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *endpoint) MTU() uint32 {
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (e *endpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
// TODO: handle dest ip addresses outside our network
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
log.Errorf("CreateOutboundPacket: %v", err)
continue
}
written++
}
return written, nil
}
func (e *endpoint) Wait() {
}
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
}
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}

View File

@ -0,0 +1,120 @@
package forwarder
import (
"fmt"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
const (
receiveWindow = 32768
maxInFlight = 1024
)
type Forwarder struct {
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
}
func New(iface common.IFaceMapper) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
},
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %w", err)
}
_, bits := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: bits,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %w", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
return nil, fmt.Errorf("creating default subnet: %w", err)
}
if s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %w", err)
}
if s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %w", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: defaultSubnet,
NIC: nicID,
},
})
f := &Forwarder{
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(),
}
// Set up TCP forwarder
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
// Set up UDP forwarder
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
return nil
}

View File

@ -0,0 +1,82 @@
package forwarder
import (
"fmt"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
dstAddr := id.LocalAddress
dstPort := id.LocalPort
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
// Dial the destination first
dialer := net.Dialer{}
outConn, err := dialer.Dial("tcp", dialAddr)
if err != nil {
r.Complete(true)
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, err2 := r.CreateEndpoint(&wq)
if err2 != nil {
if err := outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
}
// Now that we've successfully connected to the destination,
// we can complete the incoming connection
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
go f.proxyTCP(inConn, outConn)
}
func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
defer func() {
if err := inConn.Close(); err != nil {
log.Errorf("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err)
}
}()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, err := io.Copy(outConn, inConn)
if err != nil {
log.Errorf("proxyTCP: copy error: %v", err)
}
}()
go func() {
defer wg.Done()
_, err := io.Copy(inConn, outConn)
if err != nil {
log.Errorf("proxyTCP: copy error: %v", err)
}
}()
wg.Wait()
}

View File

@ -0,0 +1,153 @@
package forwarder
import (
"fmt"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
udpTimeout = 60 * time.Second
)
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastTime time.Time
}
type udpForwarder struct {
sync.RWMutex
conns map[string]*udpPacketConn
}
func newUDPForwarder() *udpForwarder {
f := &udpForwarder{
conns: make(map[string]*udpPacketConn),
}
go f.cleanup()
return f
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
f.Lock()
now := time.Now()
for addr, conn := range f.conns {
if now.Sub(conn.lastTime) > udpTimeout {
conn.conn.Close()
conn.outConn.Close()
delete(f.conns, addr)
}
}
f.Unlock()
}
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
id := r.ID()
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, err := r.CreateEndpoint(&wq)
if err != nil {
log.Errorf("Create UDP endpoint error: %v", err)
return
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
// Try to get existing connection or create a new one
f.udpForwarder.Lock()
pConn, exists := f.udpForwarder.conns[dstAddr]
if !exists {
outConn, err := net.Dial("udp", dstAddr)
if err != nil {
f.udpForwarder.Unlock()
if err := inConn.Close(); err != nil {
log.Errorf("forwader: UDP inConn close error: %v", err)
}
log.Errorf("forwarder> UDP dial error: %v", err)
return
}
pConn = &udpPacketConn{
conn: inConn,
outConn: outConn,
lastTime: time.Now(),
}
f.udpForwarder.conns[dstAddr] = pConn
go f.proxyUDP(pConn, dstAddr)
}
f.udpForwarder.Unlock()
}
func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) {
defer func() {
if err := pConn.conn.Close(); err != nil {
log.Errorf("forwarder: inConn close error: %v", err)
}
if err := pConn.outConn.Close(); err != nil {
log.Errorf("forwarder: outConn close error: %v", err)
}
}()
var wg sync.WaitGroup
wg.Add(2)
// Handle outbound to inbound traffic
go func() {
defer wg.Done()
f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound")
}()
// Handle inbound to outbound traffic
go func() {
defer wg.Done()
f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound")
}()
wg.Wait()
// Clean up the connection from the map
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, dstAddr)
f.udpForwarder.Unlock()
}
func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) {
buffer := make([]byte, 65535)
for {
n, err := src.Read(buffer)
if err != nil {
log.Errorf("UDP %s read error: %v", direction, err)
return
}
_, err = dst.Write(buffer[:n])
if err != nil {
log.Errorf("UDP %s write error: %v", direction, err)
continue
}
f.udpForwarder.Lock()
if conn, ok := f.udpForwarder.conns[dstAddr]; ok {
conn.lastTime = time.Now()
}
f.udpForwarder.Unlock()
}
}

View File

@ -2,14 +2,15 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
) )
// Rule to handle management of rules // PeerRule to handle management of rules
type Rule struct { type PeerRule struct {
id string id string
ip net.IP ip net.IP
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
@ -25,6 +26,21 @@ type Rule struct {
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *PeerRule) GetRuleID() string {
return r.id
}
type RouteRule struct {
id string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// GetRuleID returns the rule id
func (r *RouteRule) GetRuleID() string {
return r.id return r.id
} }

View File

@ -14,9 +14,9 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@ -24,34 +24,34 @@ const layerTypeAll = 0
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
// TODO: Add env var to disable routing
var ( var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
) )
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule type RuleSet map[string]PeerRule
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules map[string]RuleSet outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
routeRules map[string]RouteRule
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
mutex sync.RWMutex mutex sync.RWMutex
routingEnabled bool
stateful bool stateful bool
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder
} }
// decoder for packages // decoder for packages
@ -68,11 +68,11 @@ type decoder struct {
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) { func Create(iface common.IFaceMapper) (*Manager, error) {
return create(iface) return create(iface)
} }
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
mgr, err := create(iface) mgr, err := create(iface)
if err != nil { if err != nil {
return nil, err return nil, err
@ -82,7 +82,7 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
return mgr, nil return mgr, nil
} }
func create(iface IFaceMapper) (*Manager, error) { func create(iface common.IFaceMapper) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
m := &Manager{ m := &Manager{
@ -101,8 +101,11 @@ func create(iface IFaceMapper) (*Manager, error) {
}, },
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet),
routeRules: make(map[string]RouteRule),
wgIface: iface, wgIface: iface,
stateful: !disableConntrack, stateful: !disableConntrack,
// TODO: fix
routingEnabled: true,
} }
// Only initialize trackers if stateful mode is enabled // Only initialize trackers if stateful mode is enabled
@ -114,8 +117,23 @@ func create(iface IFaceMapper) (*Manager, error) {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
} }
intf := iface.GetWGDevice()
if intf == nil {
log.Info("forwarding not supported")
// Only supported in userspace mode as we need to inject packets back into wireguard directly
// TODO: Check if native firewall can do the job, in that case just forward everything (restores previous behavior)
m.routingEnabled = false
} else {
var err error
m.forwarder, err = forwarder.New(iface)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
m.routingEnabled = false
}
}
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, err return nil, fmt.Errorf("set filter: %w", err)
} }
return m, nil return m, nil
} }
@ -161,7 +179,7 @@ func (m *Manager) AddPeerFiltering(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
r := Rule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
@ -217,18 +235,44 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(
if m.nativeFirewall == nil { sources []netip.Prefix,
return nil, errRouteNotSupported destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String()
rule := RouteRule{
id: ruleID,
sources: sources,
destination: destination,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
} }
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
m.routeRules[ruleID] = rule
return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil { m.mutex.Lock()
return errRouteNotSupported defer m.mutex.Unlock()
ruleID := rule.GetRuleID()
if _, exists := m.routeRules[ruleID]; !exists {
return fmt.Errorf("route rule not found: %s", ruleID)
} }
return m.nativeFirewall.DeleteRouteRule(rule)
delete(m.routeRules, ruleID)
return nil
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@ -236,7 +280,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
r, ok := rule.(*Rule) r, ok := rule.(*PeerRule)
if !ok { if !ok {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
@ -279,7 +323,11 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules) return m.dropFilter(packetData, m.incomingRules)
} }
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP func (m *Manager) isLocalIP(ip net.IP) bool {
// TODO: add other interface IPs and keep track of them
return ip.Equal(m.wgIface.Address().IP)
}
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@ -300,18 +348,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
return false return false
} }
// Always process UDP hooks // Track all protocols if stateful mode is enabled
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.trackUDPOutbound(d, srcIP, dstIP)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
// Track other protocols only if stateful mode is enabled
if m.stateful { if m.stateful {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeUDP:
m.trackUDPOutbound(d, srcIP, dstIP)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP) m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
@ -319,6 +360,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
} }
} }
// Process UDP hooks even if stateful mode is disabled
if d.decoded[1] == layers.LayerTypeUDP {
return m.checkUDPHooks(d, dstIP, packetData)
}
return false return false
} }
@ -409,6 +455,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
defer m.decoders.Put(d) defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) { if !m.isValidPacket(d, packetData) {
log.Debugf("invalid packet: %v", d.decoded)
return true return true
} }
@ -418,16 +465,69 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
return true return true
} }
if !m.isWireguardTraffic(srcIP, dstIP) { // Check if this is local or routed traffic
return false isLocal := m.isLocalIP(dstIP)
}
// Check connection state only if enabled // For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false return false
} }
return m.applyRules(srcIP, packetData, rules, d) // Handle local traffic - apply peer ACLs
if isLocal {
return m.applyRules(srcIP, packetData, rules, d)
}
// Handle routed traffic
// TODO: Handle replies for [routed network -> netbird peer], we don't need to start the forwarder here
// We might need to apply NAT
// Don't handle routing if not enabled
if !m.routingEnabled {
return true
}
// Get protocol and ports for route ACL check
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
// Check route ACLs
if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) {
return true
}
// Let forwarder handle the packet if it passed route ACLs
err := m.forwarder.InjectIncomingPacket(packetData)
if err != nil {
log.Errorf("Failed to inject incoming packet: %v", err)
}
// Default: drop
return true
}
func getProtocolFromPacket(d *decoder) firewall.Protocol {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return firewall.ProtocolTCP
case layers.LayerTypeUDP:
return firewall.ProtocolUDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP
default:
return firewall.ProtocolALL
}
}
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
default:
return 0, 0
}
} }
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
@ -498,7 +598,7 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R
return true return true
} }
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) { func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) { if rule.matchByIP && !ip.Equal(rule.ip) {
@ -547,6 +647,56 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
return false, false return false, false
} }
func (m *Manager) checkRouteACLs(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
// Default deny if no rules match
matched := false
for _, rule := range m.routeRules {
// Check destination
if !rule.destination.Contains(dstAddr) {
continue
}
// Check if source matches any source prefix
sourceMatched := false
for _, src := range rule.sources {
if src.Contains(srcAddr) {
sourceMatched = true
break
}
}
if !sourceMatched {
continue
}
// Check protocol
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
continue
}
// Check ports if specified
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
continue
}
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
continue
}
matched = true
if rule.action == firewall.ActionDrop {
return false
}
}
return matched
}
// SetNetwork of the wireguard interface to which filtering applied // SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) { func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network m.wgNetwork = network
@ -558,7 +708,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
func (m *Manager) AddUDPPacketHook( func (m *Manager) AddUDPPacketHook(
in bool, ip net.IP, dPort uint16, hook func([]byte) bool, in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string { ) string {
r := Rule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
@ -577,12 +727,12 @@ func (m *Manager) AddUDPPacketHook(
if in { if in {
r.direction = firewall.RuleDirectionIN r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule) m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok { if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]Rule) m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
} }
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip.String()][r.id] = r
} }

View File

@ -194,7 +194,7 @@ func TestAddUDPPacketHook(t *testing.T) {
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))

View File

@ -3,6 +3,8 @@
package iface package iface
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@ -15,4 +17,5 @@ type WGTunDevice interface {
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
} }

View File

@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())

View File

@ -9,6 +9,7 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -153,6 +154,11 @@ func (t *TunKernelDevice) DeviceName() string {
return t.name return t.name
} }
// Device returns the wireguard device, not applicable for kernel devices
func (t *TunKernelDevice) Device() *device.Device {
return nil
}
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil return nil
} }

View File

@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string {
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunNetstackDevice) Device() *device.Device {
return t.device
}

View File

@ -128,6 +128,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *USPDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error { func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name) link := newWGLink(t.name)

View File

@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
func (t *TunDevice) GetInterfaceGUIDString() (string, error) { func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil { if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet") return "", fmt.Errorf("interface has not been initialized yet")

View File

@ -11,6 +11,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
return w.tun.FilteredDevice() return w.tun.FilteredDevice()
} }
// GetWGDevice returns the WireGuard device
func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer // GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.GetStats(peerKey) return w.configurer.GetStats(peerKey)

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"time" "time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@ -32,5 +33,6 @@ type IWGIface interface {
SetFilter(filter device.PacketFilter) error SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
} }

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"time" "time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@ -30,6 +31,7 @@ type IWGIface interface {
SetFilter(filter device.PacketFilter) error SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error) GetInterfaceGUIDString() (string, error)
} }

View File

@ -383,10 +383,10 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
if newRoute.Peer == m.pubKey { if newRoute.Peer == m.pubKey {
ownNetworkIDs[haID] = true ownNetworkIDs[haID] = true
// only linux is supported for now // only linux is supported for now
if runtime.GOOS != "linux" { //if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) // log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue // continue
} //}
newServerRoutesMap[newRoute.ID] = newRoute newServerRoutesMap[newRoute.ID] = newRoute
} }
} }

2
go.mod
View File

@ -99,6 +99,7 @@ require (
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3 gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
nhooyr.io/websocket v1.8.11 nhooyr.io/websocket v1.8.11
) )
@ -229,7 +230,6 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
k8s.io/apimachinery v0.26.2 // indirect k8s.io/apimachinery v0.26.2 // indirect
) )

2
go.sum
View File

@ -527,8 +527,6 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=