mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
Add userspace routing
This commit is contained in:
parent
b3c87cb5d1
commit
4199da4a45
@ -1,6 +1,8 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@ -10,4 +12,6 @@ type IFaceMapper interface {
|
||||
Address() device.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
}
|
||||
|
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@ -0,0 +1,16 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
device2 "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
GetWGDevice() *device2.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
79
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
79
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@ -0,0 +1,79 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||
type endpoint struct {
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
e.dispatcher = dispatcher
|
||||
}
|
||||
|
||||
func (e *endpoint) IsAttached() bool {
|
||||
return e.dispatcher != nil
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityNone
|
||||
}
|
||||
|
||||
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet through WireGuard
|
||||
address := netHeader.DestinationAddress()
|
||||
|
||||
// TODO: handle dest ip addresses outside our network
|
||||
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||
if err != nil {
|
||||
log.Errorf("CreateOutboundPacket: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (e *endpoint) Wait() {
|
||||
}
|
||||
|
||||
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
}
|
||||
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
120
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
120
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@ -0,0 +1,120 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
)
|
||||
|
||||
const (
|
||||
receiveWindow = 32768
|
||||
maxInFlight = 1024
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
mtu, err := iface.GetDevice().MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get MTU: %w", err)
|
||||
}
|
||||
nicID := tcpip.NICID(1)
|
||||
endpoint := &endpoint{
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %w", err)
|
||||
}
|
||||
|
||||
_, bits := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: bits,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %w", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||
}
|
||||
|
||||
if s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set promiscuous mode: %w", err)
|
||||
}
|
||||
if s.SetSpoofing(nicID, true); err != nil {
|
||||
return nil, fmt.Errorf("set spoofing: %w", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: defaultSubnet,
|
||||
NIC: nicID,
|
||||
},
|
||||
})
|
||||
|
||||
f := &Forwarder{
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(),
|
||||
}
|
||||
|
||||
// Set up TCP forwarder
|
||||
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
|
||||
// Set up UDP forwarder
|
||||
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
if len(payload) < header.IPv4MinimumSize {
|
||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
if f.endpoint.dispatcher != nil {
|
||||
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
82
client/firewall/uspfilter/forwarder/tcp.go
Normal file
82
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@ -0,0 +1,82 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
dstAddr := id.LocalAddress
|
||||
dstPort := id.LocalPort
|
||||
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
|
||||
|
||||
// Dial the destination first
|
||||
dialer := net.Dialer{}
|
||||
outConn, err := dialer.Dial("tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, err2 := r.CreateEndpoint(&wq)
|
||||
if err2 != nil {
|
||||
if err := outConn.Close(); err != nil {
|
||||
log.Errorf("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
|
||||
// Now that we've successfully connected to the destination,
|
||||
// we can complete the incoming connection
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
go f.proxyTCP(inConn, outConn)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
log.Errorf("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
log.Errorf("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
if err != nil {
|
||||
log.Errorf("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
if err != nil {
|
||||
log.Errorf("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
153
client/firewall/uspfilter/forwarder/udp.go
Normal file
153
client/firewall/uspfilter/forwarder/udp.go
Normal file
@ -0,0 +1,153 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
udpTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type udpPacketConn struct {
|
||||
conn *gonet.UDPConn
|
||||
outConn net.Conn
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
conns map[string]*udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder() *udpForwarder {
|
||||
f := &udpForwarder{
|
||||
conns: make(map[string]*udpPacketConn),
|
||||
}
|
||||
go f.cleanup()
|
||||
return f
|
||||
}
|
||||
|
||||
// cleanup periodically removes idle UDP connections
|
||||
func (f *udpForwarder) cleanup() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
f.Lock()
|
||||
now := time.Now()
|
||||
for addr, conn := range f.conns {
|
||||
if now.Sub(conn.lastTime) > udpTimeout {
|
||||
conn.conn.Close()
|
||||
conn.outConn.Close()
|
||||
delete(f.conns, addr)
|
||||
}
|
||||
}
|
||||
f.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
log.Errorf("Create UDP endpoint error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
|
||||
// Try to get existing connection or create a new one
|
||||
f.udpForwarder.Lock()
|
||||
pConn, exists := f.udpForwarder.conns[dstAddr]
|
||||
if !exists {
|
||||
outConn, err := net.Dial("udp", dstAddr)
|
||||
if err != nil {
|
||||
f.udpForwarder.Unlock()
|
||||
if err := inConn.Close(); err != nil {
|
||||
log.Errorf("forwader: UDP inConn close error: %v", err)
|
||||
}
|
||||
log.Errorf("forwarder> UDP dial error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
pConn = &udpPacketConn{
|
||||
conn: inConn,
|
||||
outConn: outConn,
|
||||
lastTime: time.Now(),
|
||||
}
|
||||
f.udpForwarder.conns[dstAddr] = pConn
|
||||
|
||||
go f.proxyUDP(pConn, dstAddr)
|
||||
}
|
||||
f.udpForwarder.Unlock()
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) {
|
||||
defer func() {
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
log.Errorf("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
log.Errorf("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Handle outbound to inbound traffic
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound")
|
||||
}()
|
||||
|
||||
// Handle inbound to outbound traffic
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound")
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Clean up the connection from the map
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, dstAddr)
|
||||
f.udpForwarder.Unlock()
|
||||
}
|
||||
|
||||
func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) {
|
||||
buffer := make([]byte, 65535)
|
||||
for {
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
log.Errorf("UDP %s read error: %v", direction, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
log.Errorf("UDP %s write error: %v", direction, err)
|
||||
continue
|
||||
}
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
if conn, ok := f.udpForwarder.conns[dstAddr]; ok {
|
||||
conn.lastTime = time.Now()
|
||||
}
|
||||
f.udpForwarder.Unlock()
|
||||
}
|
||||
}
|
@ -2,14 +2,15 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
@ -25,6 +26,21 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
func (r *PeerRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *RouteRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
@ -14,9 +14,9 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@ -24,34 +24,34 @@ const layerTypeAll = 0
|
||||
|
||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
|
||||
// TODO: Add env var to disable routing
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]Rule
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
routeRules map[string]RouteRule
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
routingEnabled bool
|
||||
|
||||
stateful bool
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
forwarder *forwarder.Forwarder
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@ -68,11 +68,11 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
func Create(iface common.IFaceMapper) (*Manager, error) {
|
||||
return create(iface)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
mgr, err := create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -82,7 +82,7 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
func create(iface common.IFaceMapper) (*Manager, error) {
|
||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
|
||||
m := &Manager{
|
||||
@ -101,8 +101,11 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
routeRules: make(map[string]RouteRule),
|
||||
wgIface: iface,
|
||||
stateful: !disableConntrack,
|
||||
// TODO: fix
|
||||
routingEnabled: true,
|
||||
}
|
||||
|
||||
// Only initialize trackers if stateful mode is enabled
|
||||
@ -114,8 +117,23 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
}
|
||||
|
||||
intf := iface.GetWGDevice()
|
||||
if intf == nil {
|
||||
log.Info("forwarding not supported")
|
||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||
// TODO: Check if native firewall can do the job, in that case just forward everything (restores previous behavior)
|
||||
m.routingEnabled = false
|
||||
} else {
|
||||
var err error
|
||||
m.forwarder, err = forwarder.New(iface)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create forwarder: %v", err)
|
||||
m.routingEnabled = false
|
||||
}
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@ -161,7 +179,7 @@ func (m *Manager) AddPeerFiltering(
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
@ -217,18 +235,44 @@ func (m *Manager) AddPeerFiltering(
|
||||
return []firewall.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
id: ruleID,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
m.routeRules[ruleID] = rule
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := rule.GetRuleID()
|
||||
if _, exists := m.routeRules[ruleID]; !exists {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
}
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
|
||||
delete(m.routeRules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@ -236,7 +280,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
r, ok := rule.(*PeerRule)
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
@ -279,7 +323,11 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules)
|
||||
}
|
||||
|
||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
||||
func (m *Manager) isLocalIP(ip net.IP) bool {
|
||||
// TODO: add other interface IPs and keep track of them
|
||||
return ip.Equal(m.wgIface.Address().IP)
|
||||
}
|
||||
|
||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
@ -300,18 +348,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Always process UDP hooks
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
// Track UDP state only if enabled
|
||||
if m.stateful {
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
}
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
// Track other protocols only if stateful mode is enabled
|
||||
// Track all protocols if stateful mode is enabled
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
case layers.LayerTypeICMPv4:
|
||||
@ -319,6 +360,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Process UDP hooks even if stateful mode is disabled
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@ -409,6 +455,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if !m.isValidPacket(d, packetData) {
|
||||
log.Debugf("invalid packet: %v", d.decoded)
|
||||
return true
|
||||
}
|
||||
|
||||
@ -418,18 +465,71 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
// Check if this is local or routed traffic
|
||||
isLocal := m.isLocalIP(dstIP)
|
||||
|
||||
// Check connection state only if enabled
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle local traffic - apply peer ACLs
|
||||
if isLocal {
|
||||
return m.applyRules(srcIP, packetData, rules, d)
|
||||
}
|
||||
|
||||
// Handle routed traffic
|
||||
// TODO: Handle replies for [routed network -> netbird peer], we don't need to start the forwarder here
|
||||
// We might need to apply NAT
|
||||
// Don't handle routing if not enabled
|
||||
if !m.routingEnabled {
|
||||
return true
|
||||
}
|
||||
|
||||
// Get protocol and ports for route ACL check
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
// Check route ACLs
|
||||
if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Let forwarder handle the packet if it passed route ACLs
|
||||
err := m.forwarder.InjectIncomingPacket(packetData)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to inject incoming packet: %v", err)
|
||||
}
|
||||
|
||||
// Default: drop
|
||||
return true
|
||||
}
|
||||
|
||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.LayerTypeUDP:
|
||||
return firewall.ProtocolUDP
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return firewall.ProtocolICMP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0, 0
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
log.Tracef("couldn't decode layer, err: %s", err)
|
||||
@ -498,7 +598,7 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R
|
||||
return true
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
@ -547,6 +647,56 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (m *Manager) checkRouteACLs(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||
|
||||
// Default deny if no rules match
|
||||
matched := false
|
||||
|
||||
for _, rule := range m.routeRules {
|
||||
// Check destination
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if source matches any source prefix
|
||||
sourceMatched := false
|
||||
for _, src := range rule.sources {
|
||||
if src.Contains(srcAddr) {
|
||||
sourceMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sourceMatched {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check protocol
|
||||
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check ports if specified
|
||||
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
|
||||
continue
|
||||
}
|
||||
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
if rule.action == firewall.ActionDrop {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
@ -558,7 +708,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
func (m *Manager) AddUDPPacketHook(
|
||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||
) string {
|
||||
r := Rule{
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
@ -577,12 +727,12 @@ func (m *Manager) AddUDPPacketHook(
|
||||
if in {
|
||||
r.direction = firewall.RuleDirectionIN
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
|
@ -194,7 +194,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule Rule
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
|
@ -3,6 +3,8 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@ -15,4 +17,5 @@ type WGTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
}
|
||||
|
@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||
func (t *TunDevice) assignAddr() error {
|
||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -153,6 +154,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Device returns the wireguard device, not applicable for kernel devices
|
||||
func (t *TunKernelDevice) Device() *device.Device {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||
return nil
|
||||
}
|
||||
|
@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string {
|
||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunNetstackDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
@ -128,6 +128,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *USPDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *USPDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
return t.filteredDevice
|
||||
}
|
||||
|
||||
// Device returns the wireguard device
|
||||
func (t *TunDevice) Device() *device.Device {
|
||||
return t.device
|
||||
}
|
||||
|
||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||
if t.nativeTunDevice == nil {
|
||||
return "", fmt.Errorf("interface has not been initialized yet")
|
||||
|
@ -11,6 +11,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
||||
return w.tun.FilteredDevice()
|
||||
}
|
||||
|
||||
// GetWGDevice returns the WireGuard device
|
||||
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||
return w.tun.Device()
|
||||
}
|
||||
|
||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return w.configurer.GetStats(peerKey)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@ -32,5 +33,6 @@ type IWGIface interface {
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@ -30,6 +31,7 @@ type IWGIface interface {
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
|
@ -383,10 +383,10 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[haID] = true
|
||||
// only linux is supported for now
|
||||
if runtime.GOOS != "linux" {
|
||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
continue
|
||||
}
|
||||
//if runtime.GOOS != "linux" {
|
||||
// log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||
// continue
|
||||
//}
|
||||
newServerRoutesMap[newRoute.ID] = newRoute
|
||||
}
|
||||
}
|
||||
|
2
go.mod
2
go.mod
@ -99,6 +99,7 @@ require (
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/driver/sqlite v1.5.3
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
|
||||
nhooyr.io/websocket v1.8.11
|
||||
)
|
||||
|
||||
@ -229,7 +230,6 @@ require (
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
|
||||
k8s.io/apimachinery v0.26.2 // indirect
|
||||
)
|
||||
|
||||
|
2
go.sum
2
go.sum
@ -527,8 +527,6 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
|
Loading…
x
Reference in New Issue
Block a user