From 9f32ccd4533d5301bcb901677af9816cb3408f92 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 5 Apr 2024 20:38:49 +0200 Subject: [PATCH] Rollback new routing functionality (#1805) --- .github/workflows/golang-test-darwin.yml | 3 - .github/workflows/golang-test-linux.yml | 10 +- .github/workflows/golang-test-windows.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- client/internal/engine.go | 16 - client/internal/peer/conn.go | 32 -- client/internal/relay/relay.go | 7 +- client/internal/routemanager/client.go | 63 +-- client/internal/routemanager/manager.go | 79 +-- client/internal/routemanager/manager_test.go | 30 +- client/internal/routemanager/mock.go | 5 - client/internal/routemanager/routemanager.go | 126 ----- .../routemanager/server_nonandroid.go | 57 +- client/internal/routemanager/systemops.go | 407 -------------- .../routemanager/systemops_android.go | 24 +- client/internal/routemanager/systemops_bsd.go | 1 + .../internal/routemanager/systemops_darwin.go | 61 -- .../routemanager/systemops_darwin_test.go | 100 ---- client/internal/routemanager/systemops_ios.go | 26 +- .../internal/routemanager/systemops_linux.go | 531 +++--------------- .../routemanager/systemops_linux_test.go | 207 ------- .../routemanager/systemops_nonandroid.go | 120 ++++ ...s_test.go => systemops_nonandroid_test.go} | 212 ++----- .../routemanager/systemops_nonlinux.go | 32 +- .../routemanager/systemops_unix_test.go | 234 -------- .../routemanager/systemops_windows.go | 88 +-- .../routemanager/systemops_windows_test.go | 289 ---------- client/internal/stdnet/dialer.go | 24 - client/internal/stdnet/listener.go | 20 - client/internal/wgproxy/proxy_ebpf.go | 9 +- client/internal/wgproxy/proxy_userspace.go | 4 +- go.mod | 2 +- iface/wg_configurer_kernel.go | 4 +- iface/wg_configurer_usp.go | 11 +- management/client/grpc.go | 2 - sharedsock/sock_linux.go | 10 - signal/client/grpc.go | 2 - util/grpc/dialer.go | 22 - util/net/dialer.go | 21 - util/net/dialer_generic.go | 163 ------ util/net/dialer_linux.go | 12 - util/net/dialer_nonlinux.go | 6 - util/net/listener.go | 21 - util/net/listener_generic.go | 163 ------ util/net/listener_linux.go | 14 - util/net/listener_mobile.go | 11 - util/net/listener_nonlinux.go | 6 - util/net/net.go | 17 - util/net/net_linux.go | 35 -- 49 files changed, 364 insertions(+), 2979 deletions(-) delete mode 100644 client/internal/routemanager/routemanager.go delete mode 100644 client/internal/routemanager/systemops.go delete mode 100644 client/internal/routemanager/systemops_darwin.go delete mode 100644 client/internal/routemanager/systemops_darwin_test.go delete mode 100644 client/internal/routemanager/systemops_linux_test.go create mode 100644 client/internal/routemanager/systemops_nonandroid.go rename client/internal/routemanager/{systemops_test.go => systemops_nonandroid_test.go} (59%) delete mode 100644 client/internal/routemanager/systemops_unix_test.go delete mode 100644 client/internal/routemanager/systemops_windows_test.go delete mode 100644 client/internal/stdnet/dialer.go delete mode 100644 client/internal/stdnet/listener.go delete mode 100644 util/grpc/dialer.go delete mode 100644 util/net/dialer.go delete mode 100644 util/net/dialer_generic.go delete mode 100644 util/net/dialer_linux.go delete mode 100644 util/net/dialer_nonlinux.go delete mode 100644 util/net/listener.go delete mode 100644 util/net/listener_generic.go delete mode 100644 util/net/listener_linux.go delete mode 100644 util/net/listener_mobile.go delete mode 100644 util/net/listener_nonlinux.go delete mode 100644 util/net/net.go delete mode 100644 util/net/net_linux.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index d7007c860..f8afd3d6e 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,9 +32,6 @@ jobs: restore-keys: | macos-go- - - name: Install libpcap - run: brew install libpcap - - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 42f740e9b..74e6d1203 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -36,11 +36,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib - name: Install modules run: go mod tidy @@ -71,7 +67,7 @@ jobs: uses: actions/checkout@v3 - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib - name: Install modules run: go mod tidy @@ -86,7 +82,7 @@ jobs: run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... + run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2d63acbcd..6027d3626 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 13228250d..9f543c74c 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -40,7 +40,7 @@ jobs: cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/client/internal/engine.go b/client/internal/engine.go index d6238c4b3..13ef8ce15 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -94,9 +94,6 @@ type Engine struct { // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn - beforePeerHook peer.BeforeAddPeerHookFunc - afterPeerHook peer.AfterRemovePeerHookFunc - // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -264,14 +261,6 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() - if err != nil { - log.Errorf("Failed to initialize route manager: %s", err) - } else { - e.beforePeerHook = beforePeerHook - e.afterPeerHook = afterPeerHook - } - e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -821,11 +810,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { } e.peerConns[peerKey] = conn - if e.beforePeerHook != nil && e.afterPeerHook != nil { - conn.AddBeforeAddPeerHook(e.beforePeerHook) - conn.AddAfterRemovePeerHook(e.afterPeerHook) - } - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index f3d07dcad..17ef7e87f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,7 +20,6 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" - nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -101,9 +100,6 @@ type IceCredentials struct { Pwd string } -type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error -type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error - type Conn struct { config ConnConfig mu sync.Mutex @@ -142,10 +138,6 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn - - connID nbnet.ConnectionID - beforeAddPeerHooks []BeforeAddPeerHookFunc - afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -401,14 +393,6 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } -func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { - conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) -} - -func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { - conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) -} - // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -437,13 +421,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem conn.remoteEndpoint = endpointUdpAddr log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) - conn.connID = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { - log.Errorf("Before add peer hook failed: %v", err) - } - } - err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { if conn.wgProxy != nil { @@ -534,15 +511,6 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if conn.connID != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connID); err != nil { - log.Errorf("After remove peer hook failed: %v", err) - } - } - } - conn.connID = "" - if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 84fd72e49..ad3b94f2a 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,7 +12,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request @@ -96,13 +95,15 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) switch uri.Proto { case stun.ProtoTypeUDP: var err error - conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "") + listener := &net.ListenConfig{} + conn, err = listener.ListenPacket(ctx, "udp", "") if err != nil { probeErr = fmt.Errorf("listen: %w", err) return } case stun.ProtoTypeTCP: - tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr) + dialer := &net.Dialer{} + tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr) if err != nil { probeErr = fmt.Errorf("dial: %w", err) return diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 38cf4bf65..f7ead5827 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -41,7 +41,6 @@ type clientNetwork struct { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { ctx, cancel := context.WithCancel(ctx) - client := &clientNetwork{ ctx: ctx, stop: cancel, @@ -73,18 +72,6 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { return routePeerStatuses } -// getBestRouteFromStatuses determines the most optimal route from the available routes -// within a clientNetwork, taking into account peer connection status, route metrics, and -// preference for non-relayed and direct connections. -// -// It follows these prioritization rules: -// * Connected peers: Only routes with connected peers are considered. -// * Metric: Routes with lower metrics (better) are prioritized. -// * Non-relayed: Routes without relays are preferred. -// * Direct connections: Routes with direct peer connections are favored. -// * Stability: In case of equal scores, the currently active route (if any) is maintained. -// -// It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" chosenScore := 0 @@ -171,7 +158,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { state, err := c.statusRecorder.GetPeer(peerKey) if err != nil { - return fmt.Errorf("get peer state: %v", err) + return err } delete(state.Routes, c.network.String()) @@ -185,7 +172,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) if err != nil { - return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v", + return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } return nil @@ -193,26 +180,30 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { - return fmt.Errorf("remove route %s from system, err: %v", c.network, err) + err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err } - - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route: %v", err) + err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) + if err != nil { + return fmt.Errorf("couldn't remove route %s from system, err: %v", + c.network, err) } } return nil } func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { + + var err error + routerPeerStatuses := c.getRouterPeerStatuses() chosen := c.getBestRouteFromStatuses(routerPeerStatuses) - - // If no route is chosen, remove the route from the peer and system if chosen == "" { - if err := c.removeRouteFromPeerAndSystem(); err != nil { - return fmt.Errorf("remove route from peer and system: %v", err) + err = c.removeRouteFromPeerAndSystem() + if err != nil { + return err } c.chosenRoute = nil @@ -220,7 +211,6 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } - // If the chosen route is the same as the current route, do nothing if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute.IsEqual(c.routes[chosen]) { return nil @@ -228,13 +218,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - // If a previous route exists, remove it from the peer - if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { - return fmt.Errorf("remove route from peer: %v", err) + err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) + if err != nil { + return err } } else { - // otherwise add the route to the system - if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { + err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) + if err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } @@ -255,7 +245,8 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } - if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { + err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) + if err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", c.network, c.chosenRoute.Peer, err) } @@ -296,21 +287,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { - log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err) + log.Error(err) } return case <-c.peerStateUpdate: err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system: %v", err) + log.Error(err) } case update := <-c.routeUpdate: if update.updateSerial < c.updateSerial { - log.Warnf("Received a routes update with smaller serial number, ignoring it") + log.Warnf("received a routes update with smaller serial number, ignoring it") continue } - log.Debugf("Received a new client network route update for %s", c.network) + log.Debugf("received a new client network route update for %s", c.network) c.handleUpdate(update) @@ -318,7 +309,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { - log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) + log.Error(err) } c.startPeersStatusChangeWatcher() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 36a37f02c..b624d8c34 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,10 +2,6 @@ package routemanager import ( "context" - "fmt" - "net" - "net/netip" - "net/url" "runtime" "sync" @@ -19,14 +15,8 @@ import ( "github.com/netbirdio/netbird/version" ) -var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - -// nolint:unused -var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - // Manager is a route manager interface type Manager interface { - Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -66,24 +56,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } -// Init sets up the routing -func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - if err := cleanupRouting(); err != nil { - log.Warnf("Failed cleaning up routing: %v", err) - } - - mgmtAddress := m.statusRecorder.GetManagementState().URL - signalAddress := m.statusRecorder.GetSignalState().URL - ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) - - beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) - if err != nil { - return nil, nil, fmt.Errorf("setup routing: %w", err) - } - log.Info("Routing setup complete") - return beforePeerHook, afterPeerHook, nil -} - func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -99,15 +71,9 @@ func (m *DefaultManager) Stop() { if m.serverRouter != nil { m.serverRouter.cleanUp() } - if err := cleanupRouting(); err != nil { - log.Errorf("Error cleaning up routing: %v", err) - } else { - log.Info("Routing cleanup complete") - } - m.ctx = nil } -// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps +// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -125,7 +91,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { - return fmt.Errorf("update routes: %w", err) + return err } } @@ -190,7 +156,11 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] for _, newRoute := range newRoutes { networkID := route.GetHAUniqueID(newRoute) if !ownNetworkIDs[networkID] { - if !isPrefixSupported(newRoute.Network) { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < minRangeBits { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", + version.NetbirdVersion(), newRoute.Network) continue } newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) @@ -208,38 +178,3 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } - -func isPrefixSupported(prefix netip.Prefix) bool { - switch runtime.GOOS { - case "linux", "windows", "darwin": - return true - } - - // If prefix is too small, lets assume it is a possible default prefix which is not yet supported - // we skip this prefix management - if prefix.Bits() <= minRangeBits { - log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", - version.NetbirdVersion(), prefix) - return false - } - return true -} - -// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. -func resolveURLsToIPs(urls []string) []net.IP { - var ips []net.IP - for _, rawurl := range urls { - u, err := url.Parse(rawurl) - if err != nil { - log.Errorf("Failed to parse url %s: %v", rawurl, err) - continue - } - ipAddrs, err := net.LookupIP(u.Hostname()) - if err != nil { - log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) - continue - } - ips = append(ips, ipAddrs...) - } - return ips -} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 03e77e09b..2e5cf6649 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,14 +28,13 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int - clientNetworkWatchersExpectedAllowed int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int }{ { name: "Should create 2 client networks", @@ -201,9 +200,8 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, - clientNetworkWatchersExpectedAllowed: 1, + inputSerial: 1, + clientNetworkWatchersExpected: 0, }, { name: "Remove 1 Client Route", @@ -417,10 +415,6 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) - - _, _, err = routeManager.Init() - - require.NoError(t, err, "should init route manager") defer routeManager.Stop() if testCase.removeSrvRouter { @@ -435,11 +429,7 @@ func TestManagerUpdateRoutes(t *testing.T) { err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) require.NoError(t, err, "should update routes") - expectedWatchers := testCase.clientNetworkWatchersExpected - if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { - expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed - } - require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") + require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index dd2c28e59..a1214cbb9 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,7 +6,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -17,10 +16,6 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - // InitialRouteRange mock implementation of InitialRouteRange from Manager interface func (m *MockManager) InitialRouteRange() []string { return nil diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go deleted file mode 100644 index 8f9ff9f4b..000000000 --- a/client/internal/routemanager/routemanager.go +++ /dev/null @@ -1,126 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "errors" - "fmt" - "net/netip" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -type ref struct { - count int - nexthop netip.Addr - intf string -} - -type RouteManager struct { - // refCountMap keeps track of the reference ref for prefixes - refCountMap map[netip.Prefix]ref - // prefixMap keeps track of the prefixes associated with a connection ID for removal - prefixMap map[nbnet.ConnectionID][]netip.Prefix - addRoute AddRouteFunc - removeRoute RemoveRouteFunc - mutex sync.Mutex -} - -type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) -type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error - -func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { - // TODO: read initial routing table into refCountMap - return &RouteManager{ - refCountMap: map[netip.Prefix]ref{}, - prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, - addRoute: addRoute, - removeRoute: removeRoute, - } -} - -func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - ref := rm.refCountMap[prefix] - log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) - - // Add route to the system, only if it's a new prefix - if ref.count == 0 { - log.Debugf("Adding route for prefix %s", prefix) - nexthop, intf, err := rm.addRoute(prefix) - if errors.Is(err, ErrRouteNotFound) { - return nil - } - if errors.Is(err, ErrRouteNotAllowed) { - log.Debugf("Adding route for prefix %s: %s", prefix, err) - } - if err != nil { - return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) - } - ref.nexthop = nexthop - ref.intf = intf - } - - ref.count++ - rm.refCountMap[prefix] = ref - rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) - - return nil -} - -func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - prefixes, ok := rm.prefixMap[connID] - if !ok { - log.Debugf("No prefixes found for connection ID %s", connID) - return nil - } - - var result *multierror.Error - for _, prefix := range prefixes { - ref := rm.refCountMap[prefix] - log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) - if ref.count == 1 { - log.Debugf("Removing route for prefix %s", prefix) - // TODO: don't fail if the route is not found - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - continue - } - delete(rm.refCountMap, prefix) - } else { - ref.count-- - rm.refCountMap[prefix] = ref - } - } - delete(rm.prefixMap, connID) - - return result.ErrorOrNil() -} - -// Flush removes all references and routes from the system -func (rm *RouteManager) Flush() error { - rm.mutex.Lock() - defer rm.mutex.Unlock() - - var result *multierror.Error - for prefix := range rm.refCountMap { - log.Debugf("Removing route for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { - result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) - } - } - rm.refCountMap = map[netip.Prefix]ref{} - rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} - - return result.ErrorOrNil() -} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index af82dc913..192367877 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,7 +4,6 @@ package routemanager import ( "context" - "fmt" "net/netip" "sync" @@ -49,7 +48,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er oldRoute := m.routes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) } delete(m.routes, routeID) @@ -63,7 +62,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er err := m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) continue } m.routes[id] = newRoute @@ -82,22 +81,15 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("Not removing from server network because context is done") + log.Infof("not removing from server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) + err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { - return fmt.Errorf("parse prefix: %w", err) + return err } - - err = m.firewall.RemoveRoutingRules(routerPair) - if err != nil { - return fmt.Errorf("remove routing rules: %w", err) - } - delete(m.routes, route.ID) state := m.statusRecorder.GetLocalPeerState() @@ -111,22 +103,15 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): - log.Infof("Not adding to server network because context is done") + log.Infof("not adding to server network because context is done") return m.ctx.Err() default: m.mux.Lock() defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) + err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { - return fmt.Errorf("parse prefix: %w", err) + return err } - - err = m.firewall.InsertRoutingRules(routerPair) - if err != nil { - return fmt.Errorf("insert routing rules: %w", err) - } - m.routes[route.ID] = route state := m.statusRecorder.GetLocalPeerState() @@ -144,33 +129,23 @@ func (m *defaultServerRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { - routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) + err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) if err != nil { - log.Errorf("Failed to convert route to router pair: %v", err) - continue - } - - err = m.firewall.RemoveRoutingRules(routerPair) - if err != nil { - log.Errorf("Failed to remove cleanup route: %v", err) + log.Warnf("failed to remove clean up route: %s", r.ID) } + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } - - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } -func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { - parsed, err := netip.ParsePrefix(source) - if err != nil { - return firewall.RouterPair{}, err - } +func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { + parsed := netip.MustParsePrefix(source).Masked() return firewall.RouterPair{ ID: route.ID, Source: parsed.String(), Destination: route.Network.Masked().String(), Masquerade: route.Masquerade, - }, nil + } } diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go deleted file mode 100644 index a91f53636..000000000 --- a/client/internal/routemanager/systemops.go +++ /dev/null @@ -1,407 +0,0 @@ -//go:build !android && !ios - -package routemanager - -import ( - "context" - "errors" - "fmt" - "net" - "net/netip" - - "github.com/hashicorp/go-multierror" - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" -) - -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - -var ErrRouteNotFound = errors.New("route not found") -var ErrRouteNotAllowed = errors.New("route not allowed") - -// TODO: fix: for default our wg address now appears as the default gw -func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - defaultGateway, _, err := getNextHop(addr) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(defaultGateway) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) - if defaultGateway.Is6() { - gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - var exitIntf string - gatewayHop, intf, err := getNextHop(defaultGateway) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - if intf != nil { - exitIntf = intf.Name - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) -} - -func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Warnf("Failed to get route for %s: %v", ip, err) - return netip.Addr{}, nil, ErrRouteNotFound - } - - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - if preferredSrc == nil { - return netip.Addr{}, nil, ErrRouteNotFound - } - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - - addr, ok := netip.AddrFromSlice(preferredSrc) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) - } - return addr.Unmap(), intf, nil - } - - addr, ok := netip.AddrFromSlice(gateway) - if !ok { - return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) - } - - return addr.Unmap(), intf, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - -// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. -// If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func addRouteToNonVPNIntf( - prefix netip.Prefix, - vpnIntf *iface.WGIface, - initialNextHop netip.Addr, - initialIntf *net.Interface, -) (netip.Addr, string, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(), - addr.IsLinkLocalUnicast(), - addr.IsLinkLocalMulticast(), - addr.IsInterfaceLocalMulticast(), - addr.IsUnspecified(), - addr.IsMulticast(): - - return netip.Addr{}, "", ErrRouteNotAllowed - } - - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) - if err != nil { - return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) - } - - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name - } - - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") - } - - // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values - if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { - log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) - exitNextHop = initialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } - } - - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { - return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) - } - - return exitNextHop, exitIntf, nil -} - -// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix -// in two /1 prefixes to avoid replacing the existing default route -func genericAddVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - return err - } - if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } - - // TODO: remove once IPv6 is supported on the interface - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } else if prefix == defaultv6 { - if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil - } - - return addNonExistingRoute(prefix, intf) -} - -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func addNonExistingRoute(prefix netip.Prefix, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := addRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return addToRouteTable(prefix, netip.Addr{}, intf) -} - -// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, -// it will remove the split /1 prefixes -func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - // TODO: remove once IPv6 is supported on the interface - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } else if prefix == defaultv6 { - var result *multierror.Error - if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { - result = multierror.Append(result, err) - } - - return result.ErrorOrNil() - } - - return removeFromRouteTable(prefix, netip.Addr{}, intf) -} - -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v4 default next hop: %v", err) - } - initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - log.Errorf("Unable to get initial v6 default next hop: %v", err) - } - - *routeManager = NewRouteManager( - func(prefix netip.Prefix) (netip.Addr, string, error) { - addr := prefix.Addr() - nexthop, intf := initialNextHopV4, initialIntfV4 - if addr.Is6() { - nexthop, intf = initialNextHopV6, initialIntfV6 - } - return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) - }, - removeFromRouteTable, - ) - - return setupHooks(*routeManager, initAddresses) -} - -func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil -} - -func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil -} diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 34d2d270f..950a26843 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,33 +1,13 @@ package routemanager import ( - "net" "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { return nil } -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, string) error { - return nil -} - -func removeVPNRoute(netip.Prefix, string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 173e7c0e8..b2da8075c 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -1,4 +1,5 @@ //go:build darwin || dragonfly || freebsd || netbsd || openbsd +// +build darwin dragonfly freebsd netbsd openbsd package routemanager diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go deleted file mode 100644 index f34964a83..000000000 --- a/client/internal/routemanager/systemops_darwin.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build darwin && !ios - -package routemanager - -import ( - "fmt" - "net" - "net/netip" - "os/exec" - "strings" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" -) - -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) -} - -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return routeCmd("add", prefix, nexthop, intf) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return routeCmd("delete", prefix, nexthop, intf) -} - -func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { - inet := "-inet" - if prefix.Addr().Is6() { - inet = "-inet6" - // Special case for IPv6 split default route, pointing to the wg interface fails - // TODO: Remove once we have IPv6 support on the interface - if prefix.Bits() == 1 { - intf = "lo0" - } - } - - args := []string{"-n", action, inet, prefix.String()} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) - } else if intf != "" { - args = append(args, "-interface", intf) - } - - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) - - if err != nil { - return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) - } - return nil -} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go deleted file mode 100644 index 5c5aaa24f..000000000 --- a/client/internal/routemanager/systemops_darwin_test.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build !ios - -package routemanager - -import ( - "fmt" - "net" - "os/exec" - "regexp" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var expectedVPNint = "utun100" -var expectedExternalInt = "lo0" -var expectedInternalInt = "lo0" - -func init() { - testCases = append(testCases, []testCase{ - { - name: "To more specific route without custom dialer via vpn", - destination: "10.10.0.2:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), - }, - }...) -} - -func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { - t.Helper() - - err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() - require.NoError(t, err, "Failed to create loopback alias") - - t.Cleanup(func() { - err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() - assert.NoError(t, err, "Failed to remove loopback alias") - }) - - return "lo0" -} - -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { - t.Helper() - - var originalNexthop net.IP - if dstCIDR == "0.0.0.0/0" { - var err error - originalNexthop, err = fetchOriginalGateway() - if err != nil { - t.Logf("Failed to fetch original gateway: %v", err) - } - - if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { - t.Logf("Failed to delete route: %v, output: %s", err, output) - } - } - - t.Cleanup(func() { - if originalNexthop != nil { - err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() - assert.NoError(t, err, "Failed to restore original route") - } - }) - - err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() - require.NoError(t, err, "Failed to add route") - - t.Cleanup(func() { - err := exec.Command("route", "delete", "-net", dstCIDR).Run() - assert.NoError(t, err, "Failed to remove route") - }) -} - -func fetchOriginalGateway() (net.IP, error) { - output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() - if err != nil { - return nil, err - } - - matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) - if len(matches) == 0 { - return nil, fmt.Errorf("gateway not found") - } - - return net.ParseIP(matches[1]), nil -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) - - otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) -} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 34d2d270f..aae0f8dc8 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,33 +1,15 @@ +//go:build ios + package routemanager import ( - "net" "net/netip" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { return nil } -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) - return nil -} - -func addVPNRoute(netip.Prefix, string) error { - return nil -} - -func removeVPNRoute(netip.Prefix, string) error { +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index ef4643727..0562826a5 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -3,342 +3,142 @@ package routemanager import ( - "bufio" - "context" - "errors" - "fmt" "net" "net/netip" "os" "syscall" - "time" + "unsafe" - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -const ( - // NetbirdVPNTableID is the ID of the custom routing table used by Netbird. - NetbirdVPNTableID = 0x1BD0 - // NetbirdVPNTableName is the name of the custom routing table used by Netbird. - NetbirdVPNTableName = "netbird" +// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html +// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. +type routeInfoInMemory struct { + Family byte + DstLen byte + SrcLen byte + TOS byte - // rtTablesPath is the path to the file containing the routing table names. - rtTablesPath = "/etc/iproute2/rt_tables" + Table byte + Protocol byte + Scope byte + Type byte - // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. - ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" -) - -var ErrTableIDExists = errors.New("ID exists with different name") - -var routeManager = &RouteManager{} -var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" - -type ruleParams struct { - fwmark int - tableID int - family int - priority int - invert bool - suppressPrefix int - description string + Flags uint32 } -func getSetupRules() []ruleParams { - return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, - } -} +const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" -// setupRouting establishes the routing configuration for the VPN, including essential rules -// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. -// -// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over -// potential routes received and configured for the VPN. This rule is skipped for the default route and routes -// that are not in the main table. -// -// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. -// This table is where a default route or other specific routes received from the management server are configured, -// enabling VPN connectivity. -// -// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { - if isLegacy { - log.Infof("Using legacy routing setup") - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +func addToRouteTable(prefix netip.Prefix, addr string) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err } - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" } - defer func() { - if err != nil { - if cleanErr := cleanupRouting(); cleanErr != nil { - log.Errorf("Error cleaning up routing: %v", cleanErr) - } - } - }() - - rules := getSetupRules() - for _, rule := range rules { - if err := addRule(rule); err != nil { - if errors.Is(err, syscall.EOPNOTSUPP) { - log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") - isLegacy = true - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) - } - return nil, nil, fmt.Errorf("%s: %w", rule.description, err) - } + ip, _, err := net.ParseCIDR(addr + addrMask) + if err != nil { + return err } - return nil, nil, nil -} - -// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. -// It systematically removes the three rules and any associated routing table entries to ensure a clean state. -// The function uses error aggregation to report any errors encountered during the cleanup process. -func cleanupRouting() error { - if isLegacy { - return cleanupRoutingWithRouteManager(routeManager) + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, } - var result *multierror.Error - - if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { - result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err)) - } - if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { - result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err)) + err = netlink.RouteAdd(route) + if err != nil { + return err } - rules := getSetupRules() - for _, rule := range rules { - if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { - result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) - } - } - - return result.ErrorOrNil() -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) -} - -func addVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { - return genericAddVPNRoute(prefix, intf) - } - - // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 - - // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add blackhole: %w", err) - } - } - if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add route: %w", err) - } return nil } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - if isLegacy { - return genericRemoveVPNRoute(prefix, intf) +func removeFromRouteTable(prefix netip.Prefix, addr string) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err } - // TODO remove this once we have ipv6 support - if prefix == defaultv4 { - if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove unreachable route: %w", err) - } + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" } - if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove route: %w", err) + + ip, _, err := net.ParseCIDR(addr + addrMask) + if err != nil { + return err } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, + } + + err = netlink.RouteDel(route) + if err != nil { + return err + } + return nil } func getRoutesFromTable() ([]netip.Prefix, error) { - v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) + tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) if err != nil { - return nil, fmt.Errorf("get v4 routes: %w", err) + return nil, err } - v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) + msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { - return nil, fmt.Errorf("get v6 routes: %w", err) - + return nil, err } - return append(v4Routes, v6Routes...), nil -} - -// getRoutes fetches routes from a specific routing table identified by tableID. -func getRoutes(tableID, family int) ([]netip.Prefix, error) { var prefixList []netip.Prefix - - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) - } - - for _, route := range routes { - if route.Dst != nil { - addr, ok := netip.AddrFromSlice(route.Dst.IP) - if !ok { - return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) +loop: + for _, m := range msgs { + switch m.Header.Type { + case syscall.NLMSG_DONE: + break loop + case syscall.RTM_NEWROUTE: + rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) + msg := m + attrs, err := syscall.ParseNetlinkRouteAttr(&msg) + if err != nil { + return nil, err + } + if rt.Family != syscall.AF_INET { + continue loop } - ones, _ := route.Dst.Mask.Size() - - prefix := netip.PrefixFrom(addr, ones) - if prefix.IsValid() { - prefixList = append(prefixList, prefix) + for _, attr := range attrs { + if attr.Attr.Type == syscall.RTA_DST { + addr, ok := netip.AddrFromSlice(attr.Value) + if !ok { + continue + } + mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) + cidr, _ := mask.Size() + routePrefix := netip.PrefixFrom(addr, cidr) + if routePrefix.IsValid() && routePrefix.Addr().Is4() { + prefixList = append(prefixList, routePrefix) + } + } } } } - return prefixList, nil } -// addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: getAddressFamily(prefix), - } - - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - route.Dst = ipNet - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink add route: %w", err) - } - - return nil -} - -// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. -// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. -// tableID specifies the routing table to which the unreachable route will be added. -func addUnreachableRoute(prefix netip.Prefix, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, - } - - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink add unreachable route: %w", err) - } - - return nil -} - -func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, - } - - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink remove unreachable route: %w", err) - } - - return nil - -} - -// removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: getAddressFamily(prefix), - Dst: ipNet, - } - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("netlink remove route: %w", err) - } - - return nil -} - -func flushRoutes(tableID, family int) error { - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return fmt.Errorf("list routes from table %d: %w", tableID, err) - } - - var result *multierror.Error - for i := range routes { - route := routes[i] - // unreachable default routes don't come back with Dst set - if route.Gw == nil && route.Src == nil && route.Dst == nil { - if family == netlink.FAMILY_V4 { - routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} - } else { - routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} - } - } - if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) - } - } - - return result.ErrorOrNil() -} - func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { - return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err) + return err } // check if it is already enabled @@ -347,162 +147,5 @@ func enableIPForwarding() error { return nil } - //nolint:gosec - if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil { - return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err) - } - return nil -} - -// entryExists checks if the specified ID or name already exists in the rt_tables file -// and verifies if existing names start with "netbird_". -func entryExists(file *os.File, id int) (bool, error) { - if _, err := file.Seek(0, 0); err != nil { - return false, fmt.Errorf("seek rt_tables: %w", err) - } - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - var existingID int - var existingName string - if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil { - if existingID == id { - if existingName != NetbirdVPNTableName { - return true, ErrTableIDExists - } - return true, nil - } - } - } - if err := scanner.Err(); err != nil { - return false, fmt.Errorf("scan rt_tables: %w", err) - } - return false, nil -} - -// addRoutingTableName adds human-readable names for custom routing tables. -func addRoutingTableName() error { - file, err := os.Open(rtTablesPath) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return fmt.Errorf("open rt_tables: %w", err) - } - defer func() { - if err := file.Close(); err != nil { - log.Errorf("Error closing rt_tables: %v", err) - } - }() - - exists, err := entryExists(file, NetbirdVPNTableID) - if err != nil { - return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err) - } - if exists { - return nil - } - - // Reopen the file in append mode to add new entries - if err := file.Close(); err != nil { - log.Errorf("Error closing rt_tables before appending: %v", err) - } - file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) - if err != nil { - return fmt.Errorf("open rt_tables for appending: %w", err) - } - - if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil { - return fmt.Errorf("append entry to rt_tables: %w", err) - } - - return nil -} - -// addRule adds a routing rule to a specific routing table identified by tableID. -func addRule(params ruleParams) error { - rule := netlink.NewRule() - rule.Table = params.tableID - rule.Mark = params.fwmark - rule.Family = params.family - rule.Priority = params.priority - rule.Invert = params.invert - rule.SuppressPrefixlen = params.suppressPrefix - - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return fmt.Errorf("add routing rule: %w", err) - } - - return nil -} - -// removeRule removes a routing rule from a specific routing table identified by tableID. -func removeRule(params ruleParams) error { - rule := netlink.NewRule() - rule.Table = params.tableID - rule.Mark = params.fwmark - rule.Family = params.family - rule.Invert = params.invert - rule.Priority = params.priority - rule.SuppressPrefixlen = params.suppressPrefix - - if err := netlink.RuleDel(rule); err != nil { - return fmt.Errorf("remove routing rule: %w", err) - } - - return nil -} - -func removeAllRules(params ruleParams) error { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - done := make(chan error, 1) - go func() { - for { - if ctx.Err() != nil { - done <- ctx.Err() - return - } - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { - done <- nil - return - } - done <- err - return - } - } - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-done: - return err - } -} - -// addNextHop adds the gateway and device to the route. -func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { - if addr.IsValid() { - route.Gw = addr.AsSlice() - } - - if intf != "" { - link, err := netlink.LinkByName(intf) - if err != nil { - return fmt.Errorf("set interface %s: %w", intf, err) - } - route.LinkIndex = link.Attrs().Index - } - - return nil -} - -func getAddressFamily(prefix netip.Prefix) int { - if prefix.Addr().Is4() { - return netlink.FAMILY_V4 - } - return netlink.FAMILY_V6 + return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec } diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go deleted file mode 100644 index 0043c3f4e..000000000 --- a/client/internal/routemanager/systemops_linux_test.go +++ /dev/null @@ -1,207 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "errors" - "fmt" - "net" - "os" - "strings" - "syscall" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" -) - -var expectedVPNint = "wgtest0" -var expectedLoopbackInt = "lo" -var expectedExternalInt = "dummyext0" -var expectedInternalInt = "dummyint0" - -func init() { - testCases = append(testCases, []testCase{ - { - name: "To more specific route without custom dialer via physical interface", - destination: "10.10.0.2:53", - expectedInterface: expectedInternalInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), - }, - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.1:53", - expectedInterface: expectedLoopbackInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, - }...) -} - -func TestEntryExists(t *testing.T) { - tempDir := t.TempDir() - tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) - - content := []string{ - "1000 reserved", - fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName), - "9999 other_table", - } - require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644)) - - file, err := os.Open(tempFilePath) - require.NoError(t, err) - defer func() { - assert.NoError(t, file.Close()) - }() - - tests := []struct { - name string - id int - shouldExist bool - err error - }{ - { - name: "ExistsWithNetbirdPrefix", - id: 7120, - shouldExist: true, - err: nil, - }, - { - name: "ExistsWithDifferentName", - id: 1000, - shouldExist: true, - err: ErrTableIDExists, - }, - { - name: "DoesNotExist", - id: 1234, - shouldExist: false, - err: nil, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - exists, err := entryExists(file, tc.id) - if tc.err != nil { - assert.ErrorIs(t, err, tc.err) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.shouldExist, exists) - }) - } -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} - err := netlink.LinkDel(dummy) - if err != nil && !errors.Is(err, syscall.EINVAL) { - t.Logf("Failed to delete dummy interface: %v", err) - } - - err = netlink.LinkAdd(dummy) - require.NoError(t, err) - - err = netlink.LinkSetUp(dummy) - require.NoError(t, err) - - if ipAddressCIDR != "" { - addr, err := netlink.ParseAddr(ipAddressCIDR) - require.NoError(t, err) - err = netlink.AddrAdd(dummy, addr) - require.NoError(t, err) - } - - t.Cleanup(func() { - err := netlink.LinkDel(dummy) - assert.NoError(t, err) - }) - - return dummy.Name -} - -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { - t.Helper() - - _, dstIPNet, err := net.ParseCIDR(dstCIDR) - require.NoError(t, err) - - // Handle existing routes with metric 0 - var originalNexthop net.IP - var originalLinkIndex int - if dstIPNet.String() == "0.0.0.0/0" { - var err error - originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil && !errors.Is(err, ErrRouteNotFound) { - t.Logf("Failed to fetch original gateway: %v", err) - } - - if originalNexthop != nil { - err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - switch { - case err != nil && !errors.Is(err, syscall.ESRCH): - t.Logf("Failed to delete route: %v", err) - case err == nil: - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - default: - t.Logf("Failed to delete route: %v", err) - } - } - } - - link, err := netlink.LinkByName(intf) - require.NoError(t, err) - linkIndex := link.Attrs().Index - - route := &netlink.Route{ - Dst: dstIPNet, - Gw: gw, - LinkIndex: linkIndex, - } - err = netlink.RouteDel(route) - if err != nil && !errors.Is(err, syscall.ESRCH) { - t.Logf("Failed to delete route: %v", err) - } - - err = netlink.RouteAdd(route) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - require.NoError(t, err) -} - -func fetchOriginalGateway(family int) (net.IP, int, error) { - routes, err := netlink.RouteList(nil, family) - if err != nil { - return nil, 0, err - } - - for _, route := range routes { - if route.Dst == nil && route.Priority == 0 { - return route.Gw, route.LinkIndex, nil - } - } - - return nil, 0, ErrRouteNotFound -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) - - otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) -} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go new file mode 100644 index 000000000..11247c7dc --- /dev/null +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -0,0 +1,120 @@ +//go:build !android && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" +) + +var errRouteNotFound = fmt.Errorf("route not found") + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return err + } + if ok { + log.Warnf("skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return err + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, addr) +} + +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + if err != nil && err != errRouteNotFound { + return err + } + + addr := netip.MustParseAddr(defaultGateway.String()) + + if !prefix.Contains(addr) { + log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(addr, 32) + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) + if err != nil && err != errRouteNotFound { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop.String()) +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { + return removeFromRouteTable(prefix, addr) +} + +func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { + r, err := netroute.New() + if err != nil { + return nil, err + } + _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) + if err != nil { + log.Errorf("getting routes returned an error: %v", err) + return nil, errRouteNotFound + } + + if gateway == nil { + return preferredSrc, nil + } + + return gateway, nil +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_nonandroid_test.go similarity index 59% rename from client/internal/routemanager/systemops_test.go rename to client/internal/routemanager/systemops_nonandroid_test.go index 97386f19a..6f32d9634 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -1,32 +1,24 @@ -//go:build !android && !ios +//go:build !android package routemanager import ( "bytes" - "context" "fmt" "net" "net/netip" "os" - "runtime" "strings" "testing" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/iface" ) -type dialer interface { - Dial(network, address string) (net.Conn, error) - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string @@ -61,30 +53,27 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "genericAddVPNRoute should not return err") + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { - assertWGOutInterface(t, testCase.prefix, wgInterface, false) + require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") } else { - assertWGOutInterface(t, testCase.prefix, wgInterface, true) + require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") } exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) - require.NoError(t, err, "genericRemoveVPNRoute should not return err") + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") + prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) require.NoError(t, err) if testCase.shouldBeRemoved { @@ -97,12 +86,12 @@ func TestAddRemoveRoutes(t *testing.T) { } } -func TestGetNextHop(t *testing.T) { - gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) +func TestGetExistingRIBRouteGateway(t *testing.T) { + gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if !gateway.IsValid() { + if gateway == nil { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -124,11 +113,11 @@ func TestGetNextHop(t *testing.T) { } } - localIP, _, err := getNextHop(testingPrefix.Addr()) + localIP, err := getExistingRIBRouteGateway(testingPrefix) if err != nil { t.Fatal("shouldn't return error: ", err) } - if !localIP.IsValid() { + if localIP == nil { t.Fatal("should return a gateway for local network") } if localIP.String() == gateway.String() { @@ -139,8 +128,8 @@ func TestGetNextHop(t *testing.T) { } } -func TestAddExistAndRemoveRoute(t *testing.T) { - defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) +func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { + defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) @@ -200,14 +189,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") + MockAddr := wgInterface.Address().IP.String() + // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -217,7 +208,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) require.NoError(t, err, "should not return err") } @@ -226,7 +217,6 @@ func TestAddExistAndRemoveRoute(t *testing.T) { ok, err := existsInRouteTable(testCase.prefix) t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") - if !strings.Contains(buf.String(), "because it already exists") { require.False(t, ok, "route should not exist") } @@ -234,6 +224,31 @@ func TestAddExistAndRemoveRoute(t *testing.T) { } } +func TestExistsInRouteTable(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is4() { + addressPrefixes = append(addressPrefixes, p.Masked()) + } + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} + func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { @@ -271,132 +286,3 @@ func TestIsSubRange(t *testing.T) { } } } - -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is6() { - continue - } - // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { - continue - } - // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence - if runtime.GOOS == "linux" && p.Addr().IsLoopback() { - continue - } - - addressPrefixes = append(addressPrefixes, p.Masked()) - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet() - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) { - t.Helper() - - setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - _, _, err := setupRouting(nil, wgIface) - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - // default route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.0.0.0/8 route exists in main table and vpn table - err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 10.10.0.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // 127.0.10.0/24 more specific route exists in vpn table - err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) - - // unique route in vpn table - err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - require.NoError(t, err, "addVPNRoute should not return err") - t.Cleanup(func() { - err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) - assert.NoError(t, err, "removeVPNRoute should not return err") - }) -} - -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { - return - } - - prefixGateway, _, err := getNextHop(prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 38026107e..47bd60eb0 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,23 +1,41 @@ -//go:build !linux && !ios +//go:build !linux +// +build !linux package routemanager import ( "net/netip" + "os/exec" "runtime" log "github.com/sirupsen/logrus" ) -func enableIPForwarding() error { - log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) +func addToRouteTable(prefix netip.Prefix, addr string) error { + cmd := exec.Command("route", "add", prefix.String(), addr) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) return nil } -func addVPNRoute(prefix netip.Prefix, intf string) error { - return genericAddVPNRoute(prefix, intf) +func removeFromRouteTable(prefix netip.Prefix, addr string) error { + args := []string{"delete", prefix.String()} + if runtime.GOOS == "darwin" { + args = append(args, addr) + } + cmd := exec.Command("route", args...) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) + return nil } -func removeVPNRoute(prefix netip.Prefix, intf string) error { - return genericRemoveVPNRoute(prefix, intf) +func enableIPForwarding() error { + log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil } diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go deleted file mode 100644 index 561eaeea4..000000000 --- a/client/internal/routemanager/systemops_unix_test.go +++ /dev/null @@ -1,234 +0,0 @@ -//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly - -package routemanager - -import ( - "fmt" - "net" - "strings" - "testing" - "time" - - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/gopacket/gopacket/pcap" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool -} - -type testCase struct { - name string - destination string - expectedInterface string - dialer dialer - expectedPacket PacketExpectation -} - -var testCases = []testCase{ - { - name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", - expectedInterface: expectedExternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", - expectedInterface: expectedInternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedInterface: expectedInternalInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - - { - name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", - expectedInterface: expectedExternalInt, - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), - }, - { - name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", - expectedInterface: expectedVPNint, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), - }, -} - -func TestRouting(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setupTestEnv(t) - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.expectedInterface, filter) - - sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.expectedPacket) - }) - } -} - -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, - } -} - -func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { - t.Helper() - - inactive, err := pcap.NewInactiveHandle(intf) - require.NoError(t, err, "Failed to create inactive pcap handle") - defer inactive.CleanUp() - - err = inactive.SetSnapLen(1600) - require.NoError(t, err, "Failed to set snap length on inactive handle") - - err = inactive.SetTimeout(time.Second * 10) - require.NoError(t, err, "Failed to set timeout on inactive handle") - - err = inactive.SetImmediateMode(true) - require.NoError(t, err, "Failed to set immediate mode on inactive handle") - - handle, err := inactive.Activate() - require.NoError(t, err, "Failed to activate pcap handle") - t.Cleanup(handle.Close) - - err = handle.SetBPFFilter(filter) - require.NoError(t, err, "Failed to set BPF filter") - - return handle -} - -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { - t.Helper() - - if dialer == nil { - dialer = &net.Dialer{} - } - - if sourcePort != 0 { - localUDPAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: sourcePort, - } - switch dialer := dialer.(type) { - case *nbnet.Dialer: - dialer.LocalAddr = localUDPAddr - case *net.Dialer: - dialer.LocalAddr = localUDPAddr - default: - t.Fatal("Unsupported dialer type") - } - } - - msg := new(dns.Msg) - msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = []dns.Question{ - {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - conn, err := dialer.Dial("udp", destination) - require.NoError(t, err, "Failed to dial UDP") - defer conn.Close() - - data, err := msg.Pack() - require.NoError(t, err, "Failed to pack DNS message") - - _, err = conn.Write(data) - if err != nil { - if strings.Contains(err.Error(), "required key not available") { - t.Logf("Ignoring WireGuard key error: %v", err) - return - } - t.Fatalf("Failed to send DNS query: %v", err) - } -} - -func createBPFFilter(destination string) string { - host, port, err := net.SplitHostPort(destination) - if err != nil { - return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) - } - return "udp" -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } -} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 50fff0cd5..309c184b9 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -1,19 +1,13 @@ //go:build windows +// +build windows package routemanager import ( - "fmt" "net" "net/netip" - "os/exec" - "strings" - log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -21,35 +15,23 @@ type Win32_IP4RouteTable struct { Mask string } -var routeManager *RouteManager - -func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) -} - -func cleanupRouting() error { - return cleanupRoutingWithRouteManager(routeManager) -} - func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" err := wmi.Query(query, &routes) if err != nil { - return nil, fmt.Errorf("get routes: %w", err) + return nil, err } var prefixList []netip.Prefix for _, route := range routes { addr, err := netip.ParseAddr(route.Destination) if err != nil { - log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) continue } maskSlice := net.ParseIP(route.Mask).To4() if maskSlice == nil { - log.Warnf("Unable to parse route mask %s", route.Mask) continue } mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) @@ -62,69 +44,3 @@ func getRoutesFromTable() ([]netip.Prefix, error) { } return prefixList, nil } - -func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - destinationPrefix := prefix.String() - psCmd := "New-NetRoute" - - addressFamily := "IPv4" - if prefix.Addr().Is6() { - addressFamily = "IPv6" - } - - script := fmt.Sprintf( - `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, - psCmd, addressFamily, destinationPrefix, intf, - ) - - if nexthop.IsValid() { - script = fmt.Sprintf( - `%s -NextHop "%s"`, script, nexthop, - ) - } - - out, err := exec.Command("powershell", "-Command", script).CombinedOutput() - log.Tracef("PowerShell add route: %s", string(out)) - - if err != nil { - return fmt.Errorf("PowerShell add route: %w", err) - } - - return nil -} - -func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"add", prefix.String(), nexthop.Unmap().String()} - - out, err := exec.Command("route", args...).CombinedOutput() - - log.Tracef("route %s output: %s", strings.Join(args, " "), out) - if err != nil { - return fmt.Errorf("route add: %w", err) - } - - return nil -} - -func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - // Powershell doesn't support adding routes without an interface but allows to add interface by name - if intf != "" { - return addRoutePowershell(prefix, nexthop, intf) - } - return addRouteCmd(prefix, nexthop, intf) -} - -func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { - args := []string{"delete", prefix.String()} - if nexthop.IsValid() { - args = append(args, nexthop.Unmap().String()) - } - - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s output: %s", strings.Join(args, " "), out) - - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - return nil -} diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go deleted file mode 100644 index a5e03b8d2..000000000 --- a/client/internal/routemanager/systemops_windows_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package routemanager - -import ( - "context" - "encoding/json" - "fmt" - "net" - "os/exec" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -var expectedExtInt = "Ethernet1" - -type RouteInfo struct { - NextHop string `json:"nexthop"` - InterfaceAlias string `json:"interfacealias"` - RouteMetric int `json:"routemetric"` -} - -type FindNetRouteOutput struct { - IPAddress string `json:"IPAddress"` - InterfaceIndex int `json:"InterfaceIndex"` - InterfaceAlias string `json:"InterfaceAlias"` - AddressFamily int `json:"AddressFamily"` - NextHop string `json:"NextHop"` - DestinationPrefix string `json:"DestinationPrefix"` -} - -type testCase struct { - name string - destination string - expectedSourceIP string - expectedDestPrefix string - expectedNextHop string - expectedInterface string - dialer dialer -} - -var expectedVPNint = "wgtest0" - -var testCases = []testCase{ - { - name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "128.0.0.0/1", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - { - name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", - expectedDestPrefix: "192.0.2.1/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - - { - name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", - expectedDestPrefix: "10.0.0.2/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedSourceIP: "10.0.0.1", - expectedDestPrefix: "10.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", - dialer: &net.Dialer{}, - }, - - { - name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", - expectedDestPrefix: "172.16.0.2/32", - expectedInterface: expectedExtInt, - dialer: nbnet.NewDialer(), - }, - { - name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "172.16.0.0/12", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - - { - name: "To more specific route without custom dialer via vpn interface", - destination: "10.10.0.2:53", - expectedSourceIP: "100.64.0.1", - expectedDestPrefix: "10.10.0.0/24", - expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.2:53", - expectedSourceIP: "10.0.0.1", - expectedDestPrefix: "127.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", - dialer: &net.Dialer{}, - }, -} - -func TestRouting(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setupTestEnv(t) - - route, err := fetchOriginalGateway() - require.NoError(t, err, "Failed to fetch original gateway") - ip, err := fetchInterfaceIP(route.InterfaceAlias) - require.NoError(t, err, "Failed to fetch interface IP") - - output := testRoute(t, tc.destination, tc.dialer) - if tc.expectedInterface == expectedExtInt { - verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) - } else { - verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) - } - }) - } -} - -// fetchInterfaceIP fetches the IPv4 address of the specified interface. -func fetchInterfaceIP(interfaceAlias string) (string, error) { - script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) - out, err := exec.Command("powershell", "-Command", script).Output() - if err != nil { - return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) - } - - ip := strings.TrimSpace(string(out)) - return ip, nil -} - -func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - conn, err := dialer.DialContext(ctx, "udp", destination) - require.NoError(t, err, "Failed to dial destination") - defer func() { - err := conn.Close() - assert.NoError(t, err, "Failed to close connection") - }() - - host, _, err := net.SplitHostPort(destination) - require.NoError(t, err) - - script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) - - out, err := exec.Command("powershell", "-Command", script).Output() - require.NoError(t, err, "Failed to execute Find-NetRoute") - - var outputs []FindNetRouteOutput - err = json.Unmarshal(out, &outputs) - require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") - - require.Greater(t, len(outputs), 0, "No route found for destination") - combinedOutput := combineOutputs(outputs) - - return combinedOutput -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) - require.NoError(t, err) - subnetMaskSize, _ := ipNet.Mask.Size() - script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to assign IP address to loopback adapter") - - // Wait for the IP address to be applied - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - err = waitForIPAddress(ctx, interfaceName, ip.String()) - require.NoError(t, err, "IP address not applied within timeout") - - t.Cleanup(func() { - script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to remove IP address from loopback adapter") - }) - - return interfaceName -} - -func fetchOriginalGateway() (*RouteInfo, error) { - cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") - output, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) - } - - var routeInfo RouteInfo - err = json.Unmarshal(output, &routeInfo) - if err != nil { - return nil, fmt.Errorf("failed to parse JSON output: %w", err) - } - - return &routeInfo, nil -} - -func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { - t.Helper() - - assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") - assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") - assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") - assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") -} - -func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() - if err != nil { - return err - } - - ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") - for _, ip := range ipAddresses { - if strings.TrimSpace(ip) == expectedIPAddress { - return nil - } - } - } - } -} - -func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { - var combined FindNetRouteOutput - - for _, output := range outputs { - if output.IPAddress != "" { - combined.IPAddress = output.IPAddress - } - if output.InterfaceIndex != 0 { - combined.InterfaceIndex = output.InterfaceIndex - } - if output.InterfaceAlias != "" { - combined.InterfaceAlias = output.InterfaceAlias - } - if output.AddressFamily != 0 { - combined.AddressFamily = output.AddressFamily - } - if output.NextHop != "" { - combined.NextHop = output.NextHop - } - if output.DestinationPrefix != "" { - combined.DestinationPrefix = output.DestinationPrefix - } - } - - return &combined -} - -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() - - createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") -} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go deleted file mode 100644 index e80adb42b..000000000 --- a/client/internal/stdnet/dialer.go +++ /dev/null @@ -1,24 +0,0 @@ -package stdnet - -import ( - "net" - - "github.com/pion/transport/v3" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -// Dial connects to the address on the named network. -func (n *Net) Dial(network, address string) (net.Conn, error) { - return nbnet.NewDialer().Dial(network, address) -} - -// DialUDP connects to the address on the named UDP network. -func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { - return nbnet.DialUDP(network, laddr, raddr) -} - -// DialTCP connects to the address on the named TCP network. -func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { - return nbnet.DialTCP(network, laddr, raddr) -} diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go deleted file mode 100644 index 9ce0a5556..000000000 --- a/client/internal/stdnet/listener.go +++ /dev/null @@ -1,20 +0,0 @@ -package stdnet - -import ( - "context" - "net" - - "github.com/pion/transport/v3" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -// ListenPacket listens for incoming packets on the given network and address. -func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) { - return nbnet.NewListener().ListenPacket(context.Background(), network, address) -} - -// ListenUDP acts like ListenPacket for UDP networks. -func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { - return nbnet.ListenUDP(network, locAddr) -} diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 2235c5d2b..f02b4943b 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -17,7 +17,6 @@ import ( "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" ) // WGEBPFProxy definition for proxy with EBPF support @@ -68,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - conn, err := nbnet.ListenUDP("udp", &addr) + conn, err := net.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -229,12 +228,6 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return nil, fmt.Errorf("binding to lo interface failed: %w", err) } - // Set the fwmark on the socket. - err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - // Convert the file descriptor to a PacketConn. file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index 17ebfbc49..b692ea708 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -6,8 +6,6 @@ import ( "net" log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -35,7 +33,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/go.mod b/go.mod index 29a1570c8..e4e36b966 100644 --- a/go.mod +++ b/go.mod @@ -48,7 +48,6 @@ require ( github.com/google/gopacket v1.1.19 github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 - github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 @@ -124,6 +123,7 @@ require ( github.com/google/s2a-go v0.1.4 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect + github.com/gopacket/gopacket v1.1.1 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/iface/wg_configurer_kernel.go b/iface/wg_configurer_kernel.go index 9fe987cee..36fd13cc2 100644 --- a/iface/wg_configurer_kernel.go +++ b/iface/wg_configurer_kernel.go @@ -10,8 +10,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgKernelConfigurer struct { @@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err if err != nil { return err } - fwmark := nbnet.NetbirdFwmark + fwmark := 0 config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, diff --git a/iface/wg_configurer_usp.go b/iface/wg_configurer_usp.go index 24dfadf14..200bfbc96 100644 --- a/iface/wg_configurer_usp.go +++ b/iface/wg_configurer_usp.go @@ -13,8 +13,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - nbnet "github.com/netbirdio/netbird/util/net" ) type wgUSPConfigurer struct { @@ -39,7 +37,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error if err != nil { return err } - fwmark := getFwmark() + fwmark := 0 config := wgtypes.Config{ PrivateKey: &key, ReplacePeers: true, @@ -347,10 +345,3 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { } return sb.String() } - -func getFwmark() int { - if runtime.GOOS == "linux" { - return nbnet.NetbirdFwmark - } - return 0 -} diff --git a/management/client/grpc.go b/management/client/grpc.go index 0b1804906..0234f866c 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -58,7 +57,6 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE mgmCtx, addr, transportOption, - nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 74ac6c163..02b4e174d 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -21,8 +21,6 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - - nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -84,18 +82,10 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } - if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) - } - var sockErr error rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if sockErr != nil { log.Errorf("Failed to create ipv6 raw socket: %v", err) - } else { - if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) - } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7c4535e28..7531608c3 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -23,7 +23,6 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -77,7 +76,6 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo sigCtx, addr, transportOption, - nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go deleted file mode 100644 index 96b2bc32b..000000000 --- a/util/grpc/dialer.go +++ /dev/null @@ -1,22 +0,0 @@ -package grpc - -import ( - "context" - "net" - - log "github.com/sirupsen/logrus" - "google.golang.org/grpc" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -func WithCustomDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) - if err != nil { - log.Errorf("Failed to dial: %s", err) - return nil, err - } - return conn, nil - }) -} diff --git a/util/net/dialer.go b/util/net/dialer.go deleted file mode 100644 index 0786c667e..000000000 --- a/util/net/dialer.go +++ /dev/null @@ -1,21 +0,0 @@ -package net - -import ( - "net" -) - -// Dialer extends the standard net.Dialer with the ability to execute hooks before -// and after connections. This can be used to bypass the VPN for connections using this dialer. -type Dialer struct { - *net.Dialer -} - -// NewDialer returns a customized net.Dialer with overridden Control method -func NewDialer() *Dialer { - dialer := &Dialer{ - Dialer: &net.Dialer{}, - } - dialer.init() - - return dialer -} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go deleted file mode 100644 index 06fac3bbf..000000000 --- a/util/net/dialer_generic.go +++ /dev/null @@ -1,163 +0,0 @@ -//go:build !android && !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. -func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHook removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - -func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go deleted file mode 100644 index aed5c59a3..000000000 --- a/util/net/dialer_linux.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !android - -package net - -import "syscall" - -// init configures the net.Dialer Control function to set the fwmark on the socket -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) - } -} diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go deleted file mode 100644 index 3254e6d06..000000000 --- a/util/net/dialer_nonlinux.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build !linux || android - -package net - -func (d *Dialer) init() { -} diff --git a/util/net/listener.go b/util/net/listener.go deleted file mode 100644 index f4d769f58..000000000 --- a/util/net/listener.go +++ /dev/null @@ -1,21 +0,0 @@ -package net - -import ( - "net" -) - -// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before -// responding via the socket and after closing. This can be used to bypass the VPN for listeners. -type ListenerConfig struct { - *net.ListenConfig -} - -// NewListener creates a new ListenerConfig instance. -func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } - listener.init() - - return listener -} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go deleted file mode 100644 index 451279e9d..000000000 --- a/util/net/listener_generic.go +++ /dev/null @@ -1,163 +0,0 @@ -//go:build !android && !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// RemoveListenerHooks removes all dialer hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. -func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -} - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go deleted file mode 100644 index 8d332160a..000000000 --- a/util/net/listener_linux.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !android - -package net - -import ( - "syscall" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) - } -} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go deleted file mode 100644 index 0dbbb360b..000000000 --- a/util/net/listener_mobile.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build android || ios - -package net - -import ( - "net" -) - -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - return net.ListenUDP(network, laddr) -} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go deleted file mode 100644 index fb6eadaaa..000000000 --- a/util/net/listener_nonlinux.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build !linux || android - -package net - -func (l *ListenerConfig) init() { -} diff --git a/util/net/net.go b/util/net/net.go deleted file mode 100644 index 9ea7ae803..000000000 --- a/util/net/net.go +++ /dev/null @@ -1,17 +0,0 @@ -package net - -import "github.com/google/uuid" - -const ( - // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 -) - -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -// GenerateConnID generates a unique identifier for each connection. -func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go deleted file mode 100644 index 821417500..000000000 --- a/util/net/net_linux.go +++ /dev/null @@ -1,35 +0,0 @@ -//go:build !android - -package net - -import ( - "fmt" - "syscall" -) - -// SetSocketMark sets the SO_MARK option on the given socket connection -func SetSocketMark(conn syscall.Conn) error { - sysconn, err := conn.SyscallConn() - if err != nil { - return fmt.Errorf("get raw conn: %w", err) - } - - return SetRawSocketMark(sysconn) -} - -func SetRawSocketMark(conn syscall.RawConn) error { - var setErr error - - err := conn.Control(func(fd uintptr) { - setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) - }) - if err != nil { - return fmt.Errorf("control: %w", err) - } - - if setErr != nil { - return fmt.Errorf("set SO_MARK: %w", setErr) - } - - return nil -}