From 4199da4a45cdda18ac8a5ab7fd779609c23157d0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 26 Dec 2024 15:07:27 +0100 Subject: [PATCH] Add userspace routing --- client/firewall/iface.go | 4 + client/firewall/uspfilter/common/iface.go | 16 ++ .../firewall/uspfilter/forwarder/endpoint.go | 79 ++++++ .../firewall/uspfilter/forwarder/forwarder.go | 120 +++++++++ client/firewall/uspfilter/forwarder/tcp.go | 82 ++++++ client/firewall/uspfilter/forwarder/udp.go | 153 ++++++++++++ client/firewall/uspfilter/rule.go | 22 +- client/firewall/uspfilter/uspfilter.go | 236 ++++++++++++++---- client/firewall/uspfilter/uspfilter_test.go | 2 +- client/iface/device.go | 3 + client/iface/device/device_darwin.go | 5 + client/iface/device/device_kernel_unix.go | 6 + client/iface/device/device_netstack.go | 5 + client/iface/device/device_usp_unix.go | 5 + client/iface/device/device_windows.go | 5 + client/iface/iface.go | 7 + client/iface/iwginterface.go | 2 + client/iface/iwginterface_windows.go | 2 + client/internal/routemanager/manager.go | 8 +- go.mod | 2 +- go.sum | 2 - 21 files changed, 712 insertions(+), 54 deletions(-) create mode 100644 client/firewall/uspfilter/common/iface.go create mode 100644 client/firewall/uspfilter/forwarder/endpoint.go create mode 100644 client/firewall/uspfilter/forwarder/forwarder.go create mode 100644 client/firewall/uspfilter/forwarder/tcp.go create mode 100644 client/firewall/uspfilter/forwarder/udp.go diff --git a/client/firewall/iface.go b/client/firewall/iface.go index f349f9210..d842abaa1 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,6 +1,8 @@ package firewall import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/device" ) @@ -10,4 +12,6 @@ type IFaceMapper interface { Address() device.WGAddress IsUserspaceBind() bool SetFilter(device.PacketFilter) error + GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device } diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go new file mode 100644 index 000000000..3bb128457 --- /dev/null +++ b/client/firewall/uspfilter/common/iface.go @@ -0,0 +1,16 @@ +package common + +import ( + device2 "golang.zx2c4.com/wireguard/device" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +// IFaceMapper defines subset methods of interface required for manager +type IFaceMapper interface { + SetFilter(device.PacketFilter) error + Address() iface.WGAddress + GetWGDevice() *device2.Device + GetDevice() *device.FilteredDevice +} diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go new file mode 100644 index 000000000..9f22fe3a2 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -0,0 +1,79 @@ +package forwarder + +import ( + log "github.com/sirupsen/logrus" + wgdevice "golang.zx2c4.com/wireguard/device" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device +type endpoint struct { + dispatcher stack.NetworkDispatcher + device *wgdevice.Device + mtu uint32 +} + +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +func (e *endpoint) MTU() uint32 { + return e.mtu +} + +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityNone +} + +func (e *endpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + var written int + for _, pkt := range pkts.AsSlice() { + netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice()) + + data := stack.PayloadSince(pkt.NetworkHeader()) + if data == nil { + continue + } + + // Send the packet through WireGuard + address := netHeader.DestinationAddress() + + // TODO: handle dest ip addresses outside our network + err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) + if err != nil { + log.Errorf("CreateOutboundPacket: %v", err) + continue + } + written++ + } + + return written, nil +} + +func (e *endpoint) Wait() { +} + +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +func (e *endpoint) AddHeader(*stack.PacketBuffer) { +} + +func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { + return true +} diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go new file mode 100644 index 000000000..4554ebb20 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -0,0 +1,120 @@ +package forwarder + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" +) + +const ( + receiveWindow = 32768 + maxInFlight = 1024 +) + +type Forwarder struct { + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder +} + +func New(iface common.IFaceMapper) (*Forwarder, error) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + }, + HandleLocal: false, + }) + + mtu, err := iface.GetDevice().MTU() + if err != nil { + return nil, fmt.Errorf("get MTU: %w", err) + } + nicID := tcpip.NICID(1) + endpoint := &endpoint{ + device: iface.GetWGDevice(), + mtu: uint32(mtu), + } + + if err := s.CreateNIC(nicID, endpoint); err != nil { + return nil, fmt.Errorf("failed to create NIC: %w", err) + } + + _, bits := iface.Address().Network.Mask.Size() + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), + PrefixLen: bits, + }, + } + + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("failed to add protocol address: %w", err) + } + + defaultSubnet, err := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), + ) + if err != nil { + return nil, fmt.Errorf("creating default subnet: %w", err) + } + + if s.SetPromiscuousMode(nicID, true); err != nil { + return nil, fmt.Errorf("set promiscuous mode: %w", err) + } + if s.SetSpoofing(nicID, true); err != nil { + return nil, fmt.Errorf("set spoofing: %w", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: defaultSubnet, + NIC: nicID, + }, + }) + + f := &Forwarder{ + stack: s, + endpoint: endpoint, + udpForwarder: newUDPForwarder(), + } + + // Set up TCP forwarder + tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + + // Set up UDP forwarder + udpForwarder := udp.NewForwarder(s, f.handleUDP) + s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + + log.Debugf("forwarder: Initialization complete with NIC %d", nicID) + return f, nil +} + +func (f *Forwarder) InjectIncomingPacket(payload []byte) error { + if len(payload) < header.IPv4MinimumSize { + return fmt.Errorf("packet too small: %d bytes", len(payload)) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + defer pkt.DecRef() + + if f.endpoint.dispatcher != nil { + f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) + } + return nil +} diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go new file mode 100644 index 000000000..4f406dea5 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -0,0 +1,82 @@ +package forwarder + +import ( + "fmt" + "io" + "net" + "sync" + + log "github.com/sirupsen/logrus" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// handleTCP is called by the TCP forwarder for new connections. +func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { + id := r.ID() + + dstAddr := id.LocalAddress + dstPort := id.LocalPort + dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort) + + // Dial the destination first + dialer := net.Dialer{} + outConn, err := dialer.Dial("tcp", dialAddr) + if err != nil { + r.Complete(true) + return + } + + // Create wait queue for blocking syscalls + wq := waiter.Queue{} + + ep, err2 := r.CreateEndpoint(&wq) + if err2 != nil { + if err := outConn.Close(); err != nil { + log.Errorf("forwarder: outConn close error: %v", err) + } + r.Complete(true) + return + } + + // Now that we've successfully connected to the destination, + // we can complete the incoming connection + r.Complete(false) + + inConn := gonet.NewTCPConn(&wq, ep) + + go f.proxyTCP(inConn, outConn) +} + +func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) { + defer func() { + if err := inConn.Close(); err != nil { + log.Errorf("forwarder: inConn close error: %v", err) + } + if err := outConn.Close(); err != nil { + log.Errorf("forwarder: outConn close error: %v", err) + } + }() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, err := io.Copy(outConn, inConn) + if err != nil { + log.Errorf("proxyTCP: copy error: %v", err) + } + }() + + go func() { + defer wg.Done() + _, err := io.Copy(inConn, outConn) + if err != nil { + log.Errorf("proxyTCP: copy error: %v", err) + } + }() + + wg.Wait() +} diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go new file mode 100644 index 000000000..836d570cb --- /dev/null +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -0,0 +1,153 @@ +package forwarder + +import ( + "fmt" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + udpTimeout = 60 * time.Second +) + +type udpPacketConn struct { + conn *gonet.UDPConn + outConn net.Conn + lastTime time.Time +} + +type udpForwarder struct { + sync.RWMutex + conns map[string]*udpPacketConn +} + +func newUDPForwarder() *udpForwarder { + f := &udpForwarder{ + conns: make(map[string]*udpPacketConn), + } + go f.cleanup() + return f +} + +// cleanup periodically removes idle UDP connections +func (f *udpForwarder) cleanup() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for range ticker.C { + f.Lock() + now := time.Now() + for addr, conn := range f.conns { + if now.Sub(conn.lastTime) > udpTimeout { + conn.conn.Close() + conn.outConn.Close() + delete(f.conns, addr) + } + } + f.Unlock() + } +} + +// handleUDP is called by the UDP forwarder for new packets +func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { + id := r.ID() + dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort) + + // Create wait queue for blocking syscalls + wq := waiter.Queue{} + + ep, err := r.CreateEndpoint(&wq) + if err != nil { + log.Errorf("Create UDP endpoint error: %v", err) + return + } + + inConn := gonet.NewUDPConn(f.stack, &wq, ep) + + // Try to get existing connection or create a new one + f.udpForwarder.Lock() + pConn, exists := f.udpForwarder.conns[dstAddr] + if !exists { + outConn, err := net.Dial("udp", dstAddr) + if err != nil { + f.udpForwarder.Unlock() + if err := inConn.Close(); err != nil { + log.Errorf("forwader: UDP inConn close error: %v", err) + } + log.Errorf("forwarder> UDP dial error: %v", err) + return + } + + pConn = &udpPacketConn{ + conn: inConn, + outConn: outConn, + lastTime: time.Now(), + } + f.udpForwarder.conns[dstAddr] = pConn + + go f.proxyUDP(pConn, dstAddr) + } + f.udpForwarder.Unlock() +} + +func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) { + defer func() { + if err := pConn.conn.Close(); err != nil { + log.Errorf("forwarder: inConn close error: %v", err) + } + if err := pConn.outConn.Close(); err != nil { + log.Errorf("forwarder: outConn close error: %v", err) + } + }() + + var wg sync.WaitGroup + wg.Add(2) + + // Handle outbound to inbound traffic + go func() { + defer wg.Done() + f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound") + }() + + // Handle inbound to outbound traffic + go func() { + defer wg.Done() + f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound") + }() + + wg.Wait() + + // Clean up the connection from the map + f.udpForwarder.Lock() + delete(f.udpForwarder.conns, dstAddr) + f.udpForwarder.Unlock() +} + +func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) { + buffer := make([]byte, 65535) + for { + n, err := src.Read(buffer) + if err != nil { + log.Errorf("UDP %s read error: %v", direction, err) + return + } + + _, err = dst.Write(buffer[:n]) + if err != nil { + log.Errorf("UDP %s write error: %v", direction, err) + continue + } + + f.udpForwarder.Lock() + if conn, ok := f.udpForwarder.conns[dstAddr]; ok { + conn.lastTime = time.Now() + } + f.udpForwarder.Unlock() + } +} diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index 5c1daccaf..3d199ce65 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -2,14 +2,15 @@ package uspfilter import ( "net" + "net/netip" "github.com/google/gopacket" firewall "github.com/netbirdio/netbird/client/firewall/manager" ) -// Rule to handle management of rules -type Rule struct { +// PeerRule to handle management of rules +type PeerRule struct { id string ip net.IP ipLayer gopacket.LayerType @@ -25,6 +26,21 @@ type Rule struct { } // GetRuleID returns the rule id -func (r *Rule) GetRuleID() string { +func (r *PeerRule) GetRuleID() string { + return r.id +} + +type RouteRule struct { + id string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + srcPort *firewall.Port + dstPort *firewall.Port + action firewall.Action +} + +// GetRuleID returns the rule id +func (r *RouteRule) GetRuleID() string { return r.id } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 24cfd6e96..feed1887b 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -14,9 +14,9 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -24,34 +24,34 @@ const layerTypeAll = 0 const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" +// TODO: Add env var to disable routing + var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) -// IFaceMapper defines subset methods of interface required for manager -type IFaceMapper interface { - SetFilter(device.PacketFilter) error - Address() iface.WGAddress -} - // RuleSet is a set of rules grouped by a string key -type RuleSet map[string]Rule +type RuleSet map[string]PeerRule // Manager userspace firewall manager type Manager struct { outgoingRules map[string]RuleSet incomingRules map[string]RuleSet + routeRules map[string]RouteRule wgNetwork *net.IPNet decoders sync.Pool - wgIface IFaceMapper + wgIface common.IFaceMapper nativeFirewall firewall.Manager mutex sync.RWMutex + routingEnabled bool + stateful bool udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker + forwarder *forwarder.Forwarder } // decoder for packages @@ -68,11 +68,11 @@ type decoder struct { } // Create userspace firewall manager constructor -func Create(iface IFaceMapper) (*Manager, error) { +func Create(iface common.IFaceMapper) (*Manager, error) { return create(iface) } -func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { mgr, err := create(iface) if err != nil { return nil, err @@ -82,7 +82,7 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager return mgr, nil } -func create(iface IFaceMapper) (*Manager, error) { +func create(iface common.IFaceMapper) (*Manager, error) { disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) m := &Manager{ @@ -101,8 +101,11 @@ func create(iface IFaceMapper) (*Manager, error) { }, outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), + routeRules: make(map[string]RouteRule), wgIface: iface, stateful: !disableConntrack, + // TODO: fix + routingEnabled: true, } // Only initialize trackers if stateful mode is enabled @@ -114,8 +117,23 @@ func create(iface IFaceMapper) (*Manager, error) { m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) } + intf := iface.GetWGDevice() + if intf == nil { + log.Info("forwarding not supported") + // Only supported in userspace mode as we need to inject packets back into wireguard directly + // TODO: Check if native firewall can do the job, in that case just forward everything (restores previous behavior) + m.routingEnabled = false + } else { + var err error + m.forwarder, err = forwarder.New(iface) + if err != nil { + log.Errorf("failed to create forwarder: %v", err) + m.routingEnabled = false + } + } + if err := iface.SetFilter(m); err != nil { - return nil, err + return nil, fmt.Errorf("set filter: %w", err) } return m, nil } @@ -161,7 +179,7 @@ func (m *Manager) AddPeerFiltering( ipsetName string, comment string, ) ([]firewall.Rule, error) { - r := Rule{ + r := PeerRule{ id: uuid.New().String(), ip: ip, ipLayer: layers.LayerTypeIPv6, @@ -217,18 +235,44 @@ func (m *Manager) AddPeerFiltering( return []firewall.Rule{&r}, nil } -func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { - if m.nativeFirewall == nil { - return nil, errRouteNotSupported +func (m *Manager) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + ruleID := uuid.New().String() + rule := RouteRule{ + id: ruleID, + sources: sources, + destination: destination, + proto: proto, + srcPort: sPort, + dstPort: dPort, + action: action, } - return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) + + m.routeRules[ruleID] = rule + + return &rule, nil } func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { - if m.nativeFirewall == nil { - return errRouteNotSupported + m.mutex.Lock() + defer m.mutex.Unlock() + + ruleID := rule.GetRuleID() + if _, exists := m.routeRules[ruleID]; !exists { + return fmt.Errorf("route rule not found: %s", ruleID) } - return m.nativeFirewall.DeleteRouteRule(rule) + + delete(m.routeRules, ruleID) + return nil } // DeletePeerRule from the firewall by rule definition @@ -236,7 +280,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - r, ok := rule.(*Rule) + r, ok := rule.(*PeerRule) if !ok { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } @@ -279,7 +323,11 @@ func (m *Manager) DropIncoming(packetData []byte) bool { return m.dropFilter(packetData, m.incomingRules) } -// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP +func (m *Manager) isLocalIP(ip net.IP) bool { + // TODO: add other interface IPs and keep track of them + return ip.Equal(m.wgIface.Address().IP) +} + func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -300,18 +348,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - // Always process UDP hooks - if d.decoded[1] == layers.LayerTypeUDP { - // Track UDP state only if enabled - if m.stateful { - m.trackUDPOutbound(d, srcIP, dstIP) - } - return m.checkUDPHooks(d, dstIP, packetData) - } - - // Track other protocols only if stateful mode is enabled + // Track all protocols if stateful mode is enabled if m.stateful { switch d.decoded[1] { + case layers.LayerTypeUDP: + m.trackUDPOutbound(d, srcIP, dstIP) case layers.LayerTypeTCP: m.trackTCPOutbound(d, srcIP, dstIP) case layers.LayerTypeICMPv4: @@ -319,6 +360,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { } } + // Process UDP hooks even if stateful mode is disabled + if d.decoded[1] == layers.LayerTypeUDP { + return m.checkUDPHooks(d, dstIP, packetData) + } + return false } @@ -409,6 +455,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { defer m.decoders.Put(d) if !m.isValidPacket(d, packetData) { + log.Debugf("invalid packet: %v", d.decoded) return true } @@ -418,16 +465,69 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { return true } - if !m.isWireguardTraffic(srcIP, dstIP) { - return false - } + // Check if this is local or routed traffic + isLocal := m.isLocalIP(dstIP) - // Check connection state only if enabled + // For all inbound traffic, first check if it matches a tracked connection. + // This must happen before any other filtering because the packets are statefully tracked. if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { return false } - return m.applyRules(srcIP, packetData, rules, d) + // Handle local traffic - apply peer ACLs + if isLocal { + return m.applyRules(srcIP, packetData, rules, d) + } + + // Handle routed traffic + // TODO: Handle replies for [routed network -> netbird peer], we don't need to start the forwarder here + // We might need to apply NAT + // Don't handle routing if not enabled + if !m.routingEnabled { + return true + } + + // Get protocol and ports for route ACL check + proto := getProtocolFromPacket(d) + srcPort, dstPort := getPortsFromPacket(d) + + // Check route ACLs + if !m.checkRouteACLs(srcIP, dstIP, proto, srcPort, dstPort) { + return true + } + + // Let forwarder handle the packet if it passed route ACLs + err := m.forwarder.InjectIncomingPacket(packetData) + if err != nil { + log.Errorf("Failed to inject incoming packet: %v", err) + } + + // Default: drop + return true +} + +func getProtocolFromPacket(d *decoder) firewall.Protocol { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return firewall.ProtocolTCP + case layers.LayerTypeUDP: + return firewall.ProtocolUDP + case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: + return firewall.ProtocolICMP + default: + return firewall.ProtocolALL + } +} + +func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.SrcPort), uint16(d.udp.DstPort) + default: + return 0, 0 + } } func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { @@ -498,7 +598,7 @@ func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]R return true } -func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) { +func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) { payloadLayer := d.decoded[1] for _, rule := range rules { if rule.matchByIP && !ip.Equal(rule.ip) { @@ -547,6 +647,56 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode return false, false } +func (m *Manager) checkRouteACLs(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + srcAddr, _ := netip.AddrFromSlice(srcIP) + dstAddr, _ := netip.AddrFromSlice(dstIP) + + // Default deny if no rules match + matched := false + + for _, rule := range m.routeRules { + // Check destination + if !rule.destination.Contains(dstAddr) { + continue + } + + // Check if source matches any source prefix + sourceMatched := false + for _, src := range rule.sources { + if src.Contains(srcAddr) { + sourceMatched = true + break + } + } + if !sourceMatched { + continue + } + + // Check protocol + if rule.proto != firewall.ProtocolALL && rule.proto != proto { + continue + } + + // Check ports if specified + if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) { + continue + } + if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) { + continue + } + + matched = true + if rule.action == firewall.ActionDrop { + return false + } + } + + return matched +} + // SetNetwork of the wireguard interface to which filtering applied func (m *Manager) SetNetwork(network *net.IPNet) { m.wgNetwork = network @@ -558,7 +708,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) { func (m *Manager) AddUDPPacketHook( in bool, ip net.IP, dPort uint16, hook func([]byte) bool, ) string { - r := Rule{ + r := PeerRule{ id: uuid.New().String(), ip: ip, protoLayer: layers.LayerTypeUDP, @@ -577,12 +727,12 @@ func (m *Manager) AddUDPPacketHook( if in { r.direction = firewall.RuleDirectionIN if _, ok := m.incomingRules[r.ip.String()]; !ok { - m.incomingRules[r.ip.String()] = make(map[string]Rule) + m.incomingRules[r.ip.String()] = make(map[string]PeerRule) } m.incomingRules[r.ip.String()][r.id] = r } else { if _, ok := m.outgoingRules[r.ip.String()]; !ok { - m.outgoingRules[r.ip.String()] = make(map[string]Rule) + m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) } m.outgoingRules[r.ip.String()][r.id] = r } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d3563e6f2..443d82607 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -194,7 +194,7 @@ func TestAddUDPPacketHook(t *testing.T) { manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) - var addedRule Rule + var addedRule PeerRule if tt.in { if len(manager.incomingRules[tt.ip.String()]) != 1 { t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) diff --git a/client/iface/device.go b/client/iface/device.go index 0d4e69145..2a170adfb 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -3,6 +3,8 @@ package iface import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" ) @@ -15,4 +17,5 @@ type WGTunDevice interface { DeviceName() string Close() error FilteredDevice() *device.FilteredDevice + Device() *wgdevice.Device } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index b5a128bc1..fe7ed1752 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *TunDevice) Device() *device.Device { + return t.device +} + // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (t *TunDevice) assignAddr() error { cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index f355d2cf7..978e72b79 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -9,6 +9,7 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -153,6 +154,11 @@ func (t *TunKernelDevice) DeviceName() string { return t.name } +// Device returns the wireguard device, not applicable for kernel devices +func (t *TunKernelDevice) Device() *device.Device { + return nil +} + func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { return nil } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index f5d39e9e0..c7d297187 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } + +// Device returns the wireguard device +func (t *TunNetstackDevice) Device() *device.Device { + return t.device +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 643d77565..1a154501a 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -128,6 +128,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *USPDevice) Device() *device.Device { + return t.device +} + // assignAddr Adds IP address to the tunnel interface func (t *USPDevice) assignAddr() error { link := newWGLink(t.name) diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 86968d06d..e603d7696 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *TunDevice) Device() *device.Device { + return t.device +} + func (t *TunDevice) GetInterfaceGUIDString() (string, error) { if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") diff --git a/client/iface/iface.go b/client/iface/iface.go index 1fb9c2691..64219975f 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -11,6 +11,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice { return w.tun.FilteredDevice() } +// GetWGDevice returns the WireGuard device +func (w *WGIface) GetWGDevice() *wgdevice.Device { + return w.tun.Device() +} + // GetStats returns the last handshake time, rx and tx bytes for the given peer func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { return w.configurer.GetStats(peerKey) diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go index f5ab29539..472ab45f9 100644 --- a/client/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -6,6 +6,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -32,5 +33,6 @@ type IWGIface interface { SetFilter(filter device.PacketFilter) error GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go index 96eec52a5..c9183cafd 100644 --- a/client/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -4,6 +4,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -30,6 +31,7 @@ type IWGIface interface { SetFilter(filter device.PacketFilter) error GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 389e97e2d..0d4a21ac4 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -383,10 +383,10 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] if newRoute.Peer == m.pubKey { ownNetworkIDs[haID] = true // only linux is supported for now - if runtime.GOOS != "linux" { - log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) - continue - } + //if runtime.GOOS != "linux" { + // log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) + // continue + //} newServerRoutesMap[newRoute.ID] = newRoute } } diff --git a/go.mod b/go.mod index d48280df0..b1da75512 100644 --- a/go.mod +++ b/go.mod @@ -99,6 +99,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.3 gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 nhooyr.io/websocket v1.8.11 ) @@ -229,7 +230,6 @@ require ( gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect - gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect k8s.io/apimachinery v0.26.2 // indirect ) diff --git a/go.sum b/go.sum index 540cbf20b..5511fbe30 100644 --- a/go.sum +++ b/go.sum @@ -527,8 +527,6 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= -github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=