mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
Add userspace routing
This commit is contained in:
parent
b3c87cb5d1
commit
4199da4a45
@ -1,6 +1,8 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -10,4 +12,6 @@ type IFaceMapper interface {
|
|||||||
Address() device.WGAddress
|
Address() device.WGAddress
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
device2 "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
|
type IFaceMapper interface {
|
||||||
|
SetFilter(device.PacketFilter) error
|
||||||
|
Address() iface.WGAddress
|
||||||
|
GetWGDevice() *device2.Device
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
}
|
79
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
79
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
)
|
||||||
|
|
||||||
|
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||||
|
type endpoint struct {
|
||||||
|
dispatcher stack.NetworkDispatcher
|
||||||
|
device *wgdevice.Device
|
||||||
|
mtu uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
|
e.dispatcher = dispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) IsAttached() bool {
|
||||||
|
return e.dispatcher != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MTU() uint32 {
|
||||||
|
return e.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||||
|
return stack.CapabilityNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
|
var written int
|
||||||
|
for _, pkt := range pkts.AsSlice() {
|
||||||
|
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||||
|
|
||||||
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||||
|
if data == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the packet through WireGuard
|
||||||
|
address := netHeader.DestinationAddress()
|
||||||
|
|
||||||
|
// TODO: handle dest ip addresses outside our network
|
||||||
|
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("CreateOutboundPacket: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
written++
|
||||||
|
}
|
||||||
|
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Wait() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||||
|
return header.ARPHardwareNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
|
return true
|
||||||
|
}
|
120
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
120
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
receiveWindow = 32768
|
||||||
|
maxInFlight = 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forwarder struct {
|
||||||
|
stack *stack.Stack
|
||||||
|
endpoint *endpoint
|
||||||
|
udpForwarder *udpForwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(iface common.IFaceMapper) (*Forwarder, error) {
|
||||||
|
s := stack.New(stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
|
tcp.NewProtocol,
|
||||||
|
udp.NewProtocol,
|
||||||
|
},
|
||||||
|
HandleLocal: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
mtu, err := iface.GetDevice().MTU()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get MTU: %w", err)
|
||||||
|
}
|
||||||
|
nicID := tcpip.NICID(1)
|
||||||
|
endpoint := &endpoint{
|
||||||
|
device: iface.GetWGDevice(),
|
||||||
|
mtu: uint32(mtu),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create NIC: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, bits := iface.Address().Network.Mask.Size()
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv4.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
|
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||||
|
PrefixLen: bits,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add protocol address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSubnet, err := tcpip.NewSubnet(
|
||||||
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||||
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.SetPromiscuousMode(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set promiscuous mode: %w", err)
|
||||||
|
}
|
||||||
|
if s.SetSpoofing(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set spoofing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SetRouteTable([]tcpip.Route{
|
||||||
|
{
|
||||||
|
Destination: defaultSubnet,
|
||||||
|
NIC: nicID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
f := &Forwarder{
|
||||||
|
stack: s,
|
||||||
|
endpoint: endpoint,
|
||||||
|
udpForwarder: newUDPForwarder(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up TCP forwarder
|
||||||
|
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||||
|
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
// Set up UDP forwarder
|
||||||
|
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||||
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
|
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(payload),
|
||||||
|
})
|
||||||
|
defer pkt.DecRef()
|
||||||
|
|
||||||
|
if f.endpoint.dispatcher != nil {
|
||||||
|
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
82
client/firewall/uspfilter/forwarder/tcp.go
Normal file
82
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
dstAddr := id.LocalAddress
|
||||||
|
dstPort := id.LocalPort
|
||||||
|
dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort)
|
||||||
|
|
||||||
|
// Dial the destination first
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
outConn, err := dialer.Dial("tcp", dialAddr)
|
||||||
|
if err != nil {
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
|
ep, err2 := r.CreateEndpoint(&wq)
|
||||||
|
if err2 != nil {
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we've successfully connected to the destination,
|
||||||
|
// we can complete the incoming connection
|
||||||
|
r.Complete(false)
|
||||||
|
|
||||||
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
|
go f.proxyTCP(inConn, outConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, err := io.Copy(outConn, inConn)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("proxyTCP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, err := io.Copy(inConn, outConn)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("proxyTCP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
153
client/firewall/uspfilter/forwarder/udp.go
Normal file
153
client/firewall/uspfilter/forwarder/udp.go
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
udpTimeout = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpPacketConn struct {
|
||||||
|
conn *gonet.UDPConn
|
||||||
|
outConn net.Conn
|
||||||
|
lastTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpForwarder struct {
|
||||||
|
sync.RWMutex
|
||||||
|
conns map[string]*udpPacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPForwarder() *udpForwarder {
|
||||||
|
f := &udpForwarder{
|
||||||
|
conns: make(map[string]*udpPacketConn),
|
||||||
|
}
|
||||||
|
go f.cleanup()
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup periodically removes idle UDP connections
|
||||||
|
func (f *udpForwarder) cleanup() {
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
f.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
for addr, conn := range f.conns {
|
||||||
|
if now.Sub(conn.lastTime) > udpTimeout {
|
||||||
|
conn.conn.Close()
|
||||||
|
conn.outConn.Close()
|
||||||
|
delete(f.conns, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUDP is called by the UDP forwarder for new packets
|
||||||
|
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||||
|
id := r.ID()
|
||||||
|
dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Create UDP endpoint error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||||
|
|
||||||
|
// Try to get existing connection or create a new one
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
pConn, exists := f.udpForwarder.conns[dstAddr]
|
||||||
|
if !exists {
|
||||||
|
outConn, err := net.Dial("udp", dstAddr)
|
||||||
|
if err != nil {
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
log.Errorf("forwader: UDP inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
log.Errorf("forwarder> UDP dial error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pConn = &udpPacketConn{
|
||||||
|
conn: inConn,
|
||||||
|
outConn: outConn,
|
||||||
|
lastTime: time.Now(),
|
||||||
|
}
|
||||||
|
f.udpForwarder.conns[dstAddr] = pConn
|
||||||
|
|
||||||
|
go f.proxyUDP(pConn, dstAddr)
|
||||||
|
}
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) {
|
||||||
|
defer func() {
|
||||||
|
if err := pConn.conn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
|
log.Errorf("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
// Handle outbound to inbound traffic
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Handle inbound to outbound traffic
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Clean up the connection from the map
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
delete(f.udpForwarder.conns, dstAddr)
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) {
|
||||||
|
buffer := make([]byte, 65535)
|
||||||
|
for {
|
||||||
|
n, err := src.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("UDP %s read error: %v", direction, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dst.Write(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("UDP %s write error: %v", direction, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
if conn, ok := f.udpForwarder.conns[dstAddr]; ok {
|
||||||
|
conn.lastTime = time.Now()
|
||||||
|
}
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
}
|
||||||
|
}
|
@ -2,14 +2,15 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type Rule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
ip net.IP
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
@ -25,6 +26,21 @@ type Rule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *PeerRule) GetRuleID() string {
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouteRule struct {
|
||||||
|
id string
|
||||||
|
sources []netip.Prefix
|
||||||
|
destination netip.Prefix
|
||||||
|
proto firewall.Protocol
|
||||||
|
srcPort *firewall.Port
|
||||||
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRuleID returns the rule id
|
||||||
|
func (r *RouteRule) GetRuleID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
@ -14,9 +14,9 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -24,34 +24,34 @@ const layerTypeAll = 0
|
|||||||
|
|
||||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||||
|
|
||||||
|
// TODO: Add env var to disable routing
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
SetFilter(device.PacketFilter) error
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]Rule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[string]RuleSet
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[string]RuleSet
|
||||||
|
routeRules map[string]RouteRule
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
routingEnabled bool
|
||||||
|
|
||||||
stateful bool
|
stateful bool
|
||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
|
forwarder *forwarder.Forwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@ -68,11 +68,11 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface common.IFaceMapper) (*Manager, error) {
|
||||||
return create(iface)
|
return create(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||||
mgr, err := create(iface)
|
mgr, err := create(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -82,7 +82,7 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
|
|||||||
return mgr, nil
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface IFaceMapper) (*Manager, error) {
|
func create(iface common.IFaceMapper) (*Manager, error) {
|
||||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@ -101,8 +101,11 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
},
|
},
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
|
routeRules: make(map[string]RouteRule),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
|
// TODO: fix
|
||||||
|
routingEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only initialize trackers if stateful mode is enabled
|
// Only initialize trackers if stateful mode is enabled
|
||||||
@ -114,8 +117,23 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
intf := iface.GetWGDevice()
|
||||||
|
if intf == nil {
|
||||||
|
log.Info("forwarding not supported")
|
||||||
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
|
// TODO: Check if native firewall can do the job, in that case just forward everything (restores previous behavior)
|
||||||
|
m.routingEnabled = false
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
m.forwarder, err = forwarder.New(iface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create forwarder: %v", err)
|
||||||
|
m.routingEnabled = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
@ -161,7 +179,7 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
@ -217,18 +235,44 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(
|
||||||
if m.nativeFirewall == nil {
|
sources []netip.Prefix,
|
||||||
return nil, errRouteNotSupported
|
destination netip.Prefix,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
ruleID := uuid.New().String()
|
||||||
|
rule := RouteRule{
|
||||||
|
id: ruleID,
|
||||||
|
sources: sources,
|
||||||
|
destination: destination,
|
||||||
|
proto: proto,
|
||||||
|
srcPort: sPort,
|
||||||
|
dstPort: dPort,
|
||||||
|
action: action,
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
|
||||||
|
m.routeRules[ruleID] = rule
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeFirewall == nil {
|
m.mutex.Lock()
|
||||||
return errRouteNotSupported
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
ruleID := rule.GetRuleID()
|
||||||
|
if _, exists := m.routeRules[ruleID]; !exists {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
|
||||||
|
delete(m.routeRules, ruleID)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@ -236,7 +280,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
r, ok := rule.(*PeerRule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
@ -279,7 +323,11 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
|||||||
return m.dropFilter(packetData, m.incomingRules)
|
return m.dropFilter(packetData, m.incomingRules)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
|
func (m *Manager) isLocalIP(ip net.IP) bool {
|
||||||
|
// TODO: add other interface IPs and keep track of them
|
||||||
|
return ip.Equal(m.wgIface.Address().IP)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
@ -300,18 +348,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always process UDP hooks
|
// Track all protocols if stateful mode is enabled
|
||||||
if d.decoded[1] == layers.LayerTypeUDP {
|
|
||||||
// Track UDP state only if enabled
|
|
||||||
if m.stateful {
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
return m.checkUDPHooks(d, dstIP, packetData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track other protocols only if stateful mode is enabled
|
|
||||||
if m.stateful {
|
if m.stateful {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
@ -319,6 +360,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process UDP hooks even if stateful mode is disabled
|
||||||
|
if d.decoded[1] == layers.LayerTypeUDP {
|
||||||
|
return m.checkUDPHooks(d, dstIP, packetData)
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -409,6 +455,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
|||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if !m.isValidPacket(d, packetData) {
|
if !m.isValidPacket(d, packetData) {
|
||||||
|
log.Debugf("invalid packet: %v", d.decoded)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -418,16 +465,69 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.isWireguardTraffic(srcIP, dstIP) {
|
// Check if this is local or routed traffic
|
||||||
return false
|
isLocal := m.isLocalIP(dstIP)
|
||||||
}
|
|
||||||
|
|
||||||
// Check connection state only if enabled
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle local traffic - apply peer ACLs
|
||||||
|
if isLocal {
|
||||||
return m.applyRules(srcIP, packetData, rules, d)
|
return m.applyRules(srcIP, packetData, rules, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle routed traffic
|
||||||
|
// TODO: Handle replies for [routed network -> netbird peer], we don't need to start the forwarder here
|
||||||
|
// We might need to apply NAT
|
||||||
|
// Don't handle routing if not enabled
|
||||||
|
if !m.routingEnabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get protocol and ports for route ACL check
|
||||||
|
proto := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
|
// Check route ACLs
|
||||||
|
if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
|
err := m.forwarder.InjectIncomingPacket(packetData)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to inject incoming packet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default: drop
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return firewall.ProtocolTCP
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return firewall.ProtocolUDP
|
||||||
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
|
return firewall.ProtocolICMP
|
||||||
|
default:
|
||||||
|
return firewall.ProtocolALL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
return uint16(d.udp.SrcPort), uint16(d.udp.DstPort)
|
||||||
|
default:
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||||
@ -498,7 +598,7 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
@ -547,6 +647,56 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
|||||||
return false, false
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) checkRouteACLs(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
srcAddr, _ := netip.AddrFromSlice(srcIP)
|
||||||
|
dstAddr, _ := netip.AddrFromSlice(dstIP)
|
||||||
|
|
||||||
|
// Default deny if no rules match
|
||||||
|
matched := false
|
||||||
|
|
||||||
|
for _, rule := range m.routeRules {
|
||||||
|
// Check destination
|
||||||
|
if !rule.destination.Contains(dstAddr) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if source matches any source prefix
|
||||||
|
sourceMatched := false
|
||||||
|
for _, src := range rule.sources {
|
||||||
|
if src.Contains(srcAddr) {
|
||||||
|
sourceMatched = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !sourceMatched {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check protocol
|
||||||
|
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check ports if specified
|
||||||
|
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
matched = true
|
||||||
|
if rule.action == firewall.ActionDrop {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return matched
|
||||||
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||||
m.wgNetwork = network
|
m.wgNetwork = network
|
||||||
@ -558,7 +708,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||||
) string {
|
) string {
|
||||||
r := Rule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
@ -577,12 +727,12 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
if in {
|
if in {
|
||||||
r.direction = firewall.RuleDirectionIN
|
r.direction = firewall.RuleDirectionIN
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
|
@ -194,7 +194,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
@ -15,4 +17,5 @@ type WGTunDevice interface {
|
|||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
Device() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
|
||||||
func (t *TunDevice) assignAddr() error {
|
func (t *TunDevice) assignAddr() error {
|
||||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -153,6 +154,11 @@ func (t *TunKernelDevice) DeviceName() string {
|
|||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device, not applicable for kernel devices
|
||||||
|
func (t *TunKernelDevice) Device() *device.Device {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string {
|
|||||||
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
|
||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunNetstackDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
@ -128,6 +128,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *USPDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
// assignAddr Adds IP address to the tunnel interface
|
// assignAddr Adds IP address to the tunnel interface
|
||||||
func (t *USPDevice) assignAddr() error {
|
func (t *USPDevice) assignAddr() error {
|
||||||
link := newWGLink(t.name)
|
link := newWGLink(t.name)
|
||||||
|
@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
return t.filteredDevice
|
return t.filteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Device returns the wireguard device
|
||||||
|
func (t *TunDevice) Device() *device.Device {
|
||||||
|
return t.device
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
|
||||||
if t.nativeTunDevice == nil {
|
if t.nativeTunDevice == nil {
|
||||||
return "", fmt.Errorf("interface has not been initialized yet")
|
return "", fmt.Errorf("interface has not been initialized yet")
|
||||||
|
@ -11,6 +11,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice {
|
|||||||
return w.tun.FilteredDevice()
|
return w.tun.FilteredDevice()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWGDevice returns the WireGuard device
|
||||||
|
func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
||||||
|
return w.tun.Device()
|
||||||
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats(peerKey)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@ -32,5 +33,6 @@ type IWGIface interface {
|
|||||||
SetFilter(filter device.PacketFilter) error
|
SetFilter(filter device.PacketFilter) error
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
@ -30,6 +31,7 @@ type IWGIface interface {
|
|||||||
SetFilter(filter device.PacketFilter) error
|
SetFilter(filter device.PacketFilter) error
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
GetInterfaceGUIDString() (string, error)
|
GetInterfaceGUIDString() (string, error)
|
||||||
}
|
}
|
||||||
|
@ -383,10 +383,10 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
|||||||
if newRoute.Peer == m.pubKey {
|
if newRoute.Peer == m.pubKey {
|
||||||
ownNetworkIDs[haID] = true
|
ownNetworkIDs[haID] = true
|
||||||
// only linux is supported for now
|
// only linux is supported for now
|
||||||
if runtime.GOOS != "linux" {
|
//if runtime.GOOS != "linux" {
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
// log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||||
continue
|
// continue
|
||||||
}
|
//}
|
||||||
newServerRoutesMap[newRoute.ID] = newRoute
|
newServerRoutesMap[newRoute.ID] = newRoute
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
2
go.mod
2
go.mod
@ -99,6 +99,7 @@ require (
|
|||||||
gorm.io/driver/postgres v1.5.7
|
gorm.io/driver/postgres v1.5.7
|
||||||
gorm.io/driver/sqlite v1.5.3
|
gorm.io/driver/sqlite v1.5.3
|
||||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
|
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
|
||||||
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
|
||||||
nhooyr.io/websocket v1.8.11
|
nhooyr.io/websocket v1.8.11
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -229,7 +230,6 @@ require (
|
|||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
|
|
||||||
k8s.io/apimachinery v0.26.2 // indirect
|
k8s.io/apimachinery v0.26.2 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
2
go.sum
2
go.sum
@ -527,8 +527,6 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||||
|
Loading…
x
Reference in New Issue
Block a user