Support client default routes for Linux (#1667)

All routes are now installed in a custom netbird routing table.
Management and wireguard traffic is now marked with a custom fwmark.
When the mark is present the traffic is routed via the main routing table, bypassing the VPN.
When the mark is absent the traffic is routed via the netbird routing table, if:
- there's no match in the main routing table
- it would match the default route in the routing table

IPv6 traffic is blocked when a default route IPv4 route is configured to avoid leakage.
This commit is contained in:
Viktor Liu 2024-03-21 16:49:28 +01:00 committed by GitHub
parent 846871913d
commit 2475473227
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 1656 additions and 376 deletions

View File

@ -14,8 +14,8 @@ jobs:
test: test:
strategy: strategy:
matrix: matrix:
arch: ['386','amd64'] arch: [ '386','amd64' ]
store: ['jsonfile', 'sqlite'] store: [ 'jsonfile', 'sqlite' ]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
@ -36,7 +36,11 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@ -67,7 +71,7 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@ -82,7 +86,7 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin - name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/... run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
- name: Generate nftables Manager Test bin - name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...

View File

@ -40,7 +40,7 @@ jobs:
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v3 uses: golangci/golangci-lint-action@v3
with: with:

View File

@ -230,8 +230,8 @@ func (e *Engine) Start() error {
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error()) log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
return err return fmt.Errorf("new wg interface: %w", err)
} }
e.wgInterface = wgIface e.wgInterface = wgIface
@ -244,29 +244,33 @@ func (e *Engine) Start() error {
} }
e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName) e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName)
if err != nil { if err != nil {
return err return fmt.Errorf("create rosenpass manager: %w", err)
} }
err := e.rpManager.Run() err := e.rpManager.Run()
if err != nil { if err != nil {
return err return fmt.Errorf("run rosenpass manager: %w", err)
} }
} }
initialRoutes, dnsServer, err := e.newDnsServer() initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil { if err != nil {
e.close() e.close()
return err return fmt.Errorf("create dns server: %w", err)
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
if err := e.routeManager.Init(); err != nil {
e.close()
return fmt.Errorf("init route manager: %w", err)
}
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
err = e.wgInterfaceCreate() err = e.wgInterfaceCreate()
if err != nil { if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close() e.close()
return err return fmt.Errorf("create wg interface: %w", err)
} }
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
@ -278,7 +282,7 @@ func (e *Engine) Start() error {
err = e.routeManager.EnableServerRouter(e.firewall) err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil { if err != nil {
e.close() e.close()
return err return fmt.Errorf("enable server router: %w", err)
} }
} }
@ -286,7 +290,7 @@ func (e *Engine) Start() error {
if err != nil { if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close() e.close()
return err return fmt.Errorf("up wg interface: %w", err)
} }
if e.firewall != nil { if e.firewall != nil {
@ -296,7 +300,7 @@ func (e *Engine) Start() error {
err = e.dnsServer.Initialize() err = e.dnsServer.Initialize()
if err != nil { if err != nil {
e.close() e.close()
return err return fmt.Errorf("initialize dns server: %w", err)
} }
e.receiveSignalEvents() e.receiveSignalEvents()

View File

@ -10,6 +10,9 @@ import (
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
"github.com/pion/turn/v3" "github.com/pion/turn/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// ProbeResult holds the info about the result of a relay probe request // ProbeResult holds the info about the result of a relay probe request
@ -27,7 +30,15 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
} }
}() }()
client, err := stun.DialURI(uri, &stun.DialConfig{}) net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
}
client, err := stun.DialURI(uri, &stun.DialConfig{
Net: net,
})
if err != nil { if err != nil {
probeErr = fmt.Errorf("dial: %w", err) probeErr = fmt.Errorf("dial: %w", err)
return return
@ -85,14 +96,13 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
switch uri.Proto { switch uri.Proto {
case stun.ProtoTypeUDP: case stun.ProtoTypeUDP:
var err error var err error
conn, err = net.ListenPacket("udp", "") conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
if err != nil { if err != nil {
probeErr = fmt.Errorf("listen: %w", err) probeErr = fmt.Errorf("listen: %w", err)
return return
} }
case stun.ProtoTypeTCP: case stun.ProtoTypeTCP:
dialer := net.Dialer{} tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
if err != nil { if err != nil {
probeErr = fmt.Errorf("dial: %w", err) probeErr = fmt.Errorf("dial: %w", err)
return return
@ -109,12 +119,18 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
} }
}() }()
net, err := stdnet.NewNet(nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
}
cfg := &turn.ClientConfig{ cfg := &turn.ClientConfig{
STUNServerAddr: turnServerAddr, STUNServerAddr: turnServerAddr,
TURNServerAddr: turnServerAddr, TURNServerAddr: turnServerAddr,
Conn: conn, Conn: conn,
Username: uri.Username, Username: uri.Username,
Password: uri.Password, Password: uri.Password,
Net: net,
} }
client, err := turn.NewClient(cfg) client, err := turn.NewClient(cfg)
if err != nil { if err != nil {

View File

@ -41,6 +41,7 @@ type clientNetwork struct {
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &clientNetwork{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
@ -72,6 +73,18 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
return routePeerStatuses return routePeerStatuses
} }
// getBestRouteFromStatuses determines the most optimal route from the available routes
// within a clientNetwork, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections.
//
// It follows these prioritization rules:
// * Connected peers: Only routes with connected peers are considered.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
chosen := "" chosen := ""
chosenScore := 0 chosenScore := 0
@ -158,7 +171,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey) state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil { if err != nil {
return err return fmt.Errorf("get peer state: %v", err)
} }
delete(state.Routes, c.network.String()) delete(state.Routes, c.network.String())
@ -172,7 +185,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil { if err != nil {
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err) c.network, c.chosenRoute.Peer, err)
} }
return nil return nil
@ -180,30 +193,26 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
func (c *clientNetwork) removeRouteFromPeerAndSystem() error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil { if c.chosenRoute != nil {
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
if err != nil { return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
return err
} }
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
if err != nil { if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("couldn't remove route %s from system, err: %v", return fmt.Errorf("remove route: %v", err)
c.network, err)
} }
} }
return nil return nil
} }
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
var err error
routerPeerStatuses := c.getRouterPeerStatuses() routerPeerStatuses := c.getRouterPeerStatuses()
chosen := c.getBestRouteFromStatuses(routerPeerStatuses) chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system
if chosen == "" { if chosen == "" {
err = c.removeRouteFromPeerAndSystem() if err := c.removeRouteFromPeerAndSystem(); err != nil {
if err != nil { return fmt.Errorf("remove route from peer and system: %v", err)
return err
} }
c.chosenRoute = nil c.chosenRoute = nil
@ -211,6 +220,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return nil return nil
} }
// If the chosen route is the same as the current route, do nothing
if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
if c.chosenRoute.IsEqual(c.routes[chosen]) { if c.chosenRoute.IsEqual(c.routes[chosen]) {
return nil return nil
@ -218,13 +228,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
} }
if c.chosenRoute != nil { if c.chosenRoute != nil {
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) // If a previous route exists, remove it from the peer
if err != nil { if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return err return fmt.Errorf("remove route from peer: %v", err)
} }
} else { } else {
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) // otherwise add the route to the system
if err != nil { if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err) c.network.String(), c.wgInterface.Address().IP.String(), err)
} }
@ -245,8 +255,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
} }
} }
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
if err != nil {
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err) c.network, c.chosenRoute.Peer, err)
} }
@ -287,21 +296,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
log.Debugf("stopping watcher for network %s", c.network) log.Debugf("stopping watcher for network %s", c.network)
err := c.removeRouteFromPeerAndSystem() err := c.removeRouteFromPeerAndSystem()
if err != nil { if err != nil {
log.Error(err) log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
} }
return return
case <-c.peerStateUpdate: case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem() err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil { if err != nil {
log.Error(err) log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
} }
case update := <-c.routeUpdate: case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial { if update.updateSerial < c.updateSerial {
log.Warnf("received a routes update with smaller serial number, ignoring it") log.Warnf("Received a routes update with smaller serial number, ignoring it")
continue continue
} }
log.Debugf("received a new client network route update for %s", c.network) log.Debugf("Received a new client network route update for %s", c.network)
c.handleUpdate(update) c.handleUpdate(update)
@ -309,7 +318,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
err := c.recalculateRouteAndUpdatePeerAndSystem() err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil { if err != nil {
log.Error(err) log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
} }
c.startPeersStatusChangeWatcher() c.startPeersStatusChangeWatcher()

View File

@ -2,6 +2,8 @@ package routemanager
import ( import (
"context" "context"
"fmt"
"net/netip"
"runtime" "runtime"
"sync" "sync"
@ -15,8 +17,14 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
// nolint:unused
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() error
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
@ -56,6 +64,19 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
return dm return dm
} }
// Init sets up the routing
func (m *DefaultManager) Init() error {
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
if err := setupRouting(); err != nil {
return fmt.Errorf("setup routing: %w", err)
}
log.Info("Routing setup complete")
return nil
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
@ -71,9 +92,15 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil { if m.serverRouter != nil {
m.serverRouter.cleanUp() m.serverRouter.cleanUp()
} }
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
m.ctx = nil
} }
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@ -91,7 +118,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
if m.serverRouter != nil { if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil { if err != nil {
return err return fmt.Errorf("update routes: %w", err)
} }
} }
@ -156,11 +183,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute) networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] { if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported if !isPrefixSupported(newRoute.Network) {
// we skip this route management
if newRoute.Network.Bits() < minRangeBits {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
version.NetbirdVersion(), newRoute.Network)
continue continue
} }
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
@ -178,3 +201,18 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
} }
return rs return rs
} }
func isPrefixSupported(prefix netip.Prefix) bool {
if runtime.GOOS == "linux" {
return true
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management
if prefix.Bits() < minRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), prefix)
return false
}
return true
}

View File

@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1"
func TestManagerUpdateRoutes(t *testing.T) { func TestManagerUpdateRoutes(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
inputInitRoutes []*route.Route inputInitRoutes []*route.Route
inputRoutes []*route.Route inputRoutes []*route.Route
inputSerial uint64 inputSerial uint64
removeSrvRouter bool removeSrvRouter bool
serverRoutesExpected int serverRoutesExpected int
clientNetworkWatchersExpected int clientNetworkWatchersExpected int
clientNetworkWatchersExpectedLinux int
}{ }{
{ {
name: "Should create 2 client networks", name: "Should create 2 client networks",
@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
Enabled: true, Enabled: true,
}, },
}, },
inputSerial: 1, inputSerial: 1,
clientNetworkWatchersExpected: 0, clientNetworkWatchersExpected: 0,
clientNetworkWatchersExpectedLinux: 1,
}, },
{ {
name: "Remove 1 Client Route", name: "Remove 1 Client Route",
@ -415,6 +417,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
err = routeManager.Init()
require.NoError(t, err, "should init route manager")
defer routeManager.Stop() defer routeManager.Stop()
if testCase.removeSrvRouter { if testCase.removeSrvRouter {
@ -429,7 +433,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") expectedWatchers := testCase.clientNetworkWatchersExpected
if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedLinux
}
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil { if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
sr := routeManager.serverRouter.(*defaultServerRouter) sr := routeManager.serverRouter.(*defaultServerRouter)

View File

@ -16,6 +16,10 @@ type MockManager struct {
StopFunc func() StopFunc func()
} }
func (m *MockManager) Init() error {
return nil
}
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface // InitialRouteRange mock implementation of InitialRouteRange from Manager interface
func (m *MockManager) InitialRouteRange() []string { func (m *MockManager) InitialRouteRange() []string {
return nil return nil

View File

@ -4,6 +4,7 @@ package routemanager
import ( import (
"context" "context"
"fmt"
"net/netip" "net/netip"
"sync" "sync"
@ -48,7 +49,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
oldRoute := m.routes[routeID] oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute) err := m.removeFromServerNetwork(oldRoute)
if err != nil { if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err) oldRoute.ID, oldRoute.Network, err)
} }
delete(m.routes, routeID) delete(m.routes, routeID)
@ -62,7 +63,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
err := m.addToServerNetwork(newRoute) err := m.addToServerNetwork(newRoute)
if err != nil { if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
continue continue
} }
m.routes[id] = newRoute m.routes[id] = newRoute
@ -81,15 +82,22 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not removing from server network because context is done") log.Infof("Not removing from server network because context is done")
return m.ctx.Err() return m.ctx.Err()
default: default:
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
if err != nil { if err != nil {
return err return fmt.Errorf("parse prefix: %w", err)
} }
err = m.firewall.RemoveRoutingRules(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
delete(m.routes, route.ID) delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState() state := m.statusRecorder.GetLocalPeerState()
@ -103,15 +111,22 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not adding to server network because context is done") log.Infof("Not adding to server network because context is done")
return m.ctx.Err() return m.ctx.Err()
default: default:
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
if err != nil { if err != nil {
return err return fmt.Errorf("parse prefix: %w", err)
} }
err = m.firewall.InsertRoutingRules(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
m.routes[route.ID] = route m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState() state := m.statusRecorder.GetLocalPeerState()
@ -129,9 +144,15 @@ func (m *defaultServerRouter) cleanUp() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
for _, r := range m.routes { for _, r := range m.routes {
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
if err != nil { if err != nil {
log.Warnf("failed to remove clean up route: %s", r.ID) log.Errorf("Failed to convert route to router pair: %v", err)
continue
}
err = m.firewall.RemoveRoutingRules(routerPair)
if err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
} }
state := m.statusRecorder.GetLocalPeerState() state := m.statusRecorder.GetLocalPeerState()
@ -139,13 +160,15 @@ func (m *defaultServerRouter) cleanUp() {
m.statusRecorder.UpdateLocalPeerState(state) m.statusRecorder.UpdateLocalPeerState(state)
} }
} }
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { parsed, err := netip.ParsePrefix(source)
parsed := netip.MustParsePrefix(source).Masked() if err != nil {
return firewall.RouterPair{}, err
}
return firewall.RouterPair{ return firewall.RouterPair{
ID: route.ID, ID: route.ID,
Source: parsed.String(), Source: parsed.String(),
Destination: route.Network.Masked().String(), Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade, Masquerade: route.Masquerade,
} }, nil
} }

View File

@ -4,10 +4,10 @@ import (
"net/netip" "net/netip"
) )
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
return nil return nil
} }
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
return nil return nil
} }

View File

@ -1,5 +1,4 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd //go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd
package routemanager package routemanager

View File

@ -0,0 +1,13 @@
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios
package routemanager
import "net/netip"
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
}

View File

@ -1,15 +1,13 @@
//go:build ios
package routemanager package routemanager
import ( import (
"net/netip" "net/netip"
) )
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
return nil return nil
} }
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
return nil return nil
} }

View File

@ -3,142 +3,298 @@
package routemanager package routemanager
import ( import (
"bufio"
"errors"
"fmt"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"syscall" "syscall"
"unsafe"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html const (
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'. // NetbirdVPNTableID is the ID of the custom routing table used by Netbird.
type routeInfoInMemory struct { NetbirdVPNTableID = 0x1BD0
Family byte // NetbirdVPNTableName is the name of the custom routing table used by Netbird.
DstLen byte NetbirdVPNTableName = "netbird"
SrcLen byte
TOS byte
Table byte // rtTablesPath is the path to the file containing the routing table names.
Protocol byte rtTablesPath = "/etc/iproute2/rt_tables"
Scope byte
Type byte
Flags uint32 // ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
)
var ErrTableIDExists = errors.New("ID exists with different name")
type ruleParams struct {
fwmark int
tableID int
family int
priority int
invert bool
suppressPrefix int
description string
} }
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" func getSetupRules() []ruleParams {
return []ruleParams{
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"},
}
}
func addToRouteTable(prefix netip.Prefix, addr string) error { // setupRouting establishes the routing configuration for the VPN, including essential rules
_, ipNet, err := net.ParseCIDR(prefix.String()) // to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
if err != nil { //
return err // Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
// potential routes received and configured for the VPN. This rule is skipped for the default route and routes
// that are not in the main table.
//
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting() (err error) {
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
} }
addrMask := "/32" defer func() {
if prefix.Addr().Unmap().Is6() { if err != nil {
addrMask = "/128" if cleanErr := cleanupRouting(); cleanErr != nil {
} log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
}()
ip, _, err := net.ParseCIDR(addr + addrMask) rules := getSetupRules()
if err != nil { for _, rule := range rules {
return err if err := addRule(rule); err != nil {
} return fmt.Errorf("%s: %w", rule.description, err)
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Dst: ipNet,
Gw: ip,
}
err = netlink.RouteAdd(route)
if err != nil {
return err
} }
return nil return nil
} }
func removeFromRouteTable(prefix netip.Prefix, addr string) error { // cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
_, ipNet, err := net.ParseCIDR(prefix.String()) // It systematically removes the three rules and any associated routing table entries to ensure a clean state.
if err != nil { // The function uses error aggregation to report any errors encountered during the cleanup process.
return err func cleanupRouting() error {
var result *multierror.Error
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err))
}
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err))
} }
addrMask := "/32" rules := getSetupRules()
if prefix.Addr().Unmap().Is6() { for _, rule := range rules {
addrMask = "/128" if err := removeAllRules(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
} }
ip, _, err := net.ParseCIDR(addr + addrMask) return result.ErrorOrNil()
if err != nil { }
return err
}
route := &netlink.Route{ func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error {
Scope: netlink.SCOPE_UNIVERSE, // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2
Dst: ipNet,
Gw: ip,
}
err = netlink.RouteDel(route) // TODO remove this once we have ipv6 support
if err != nil { if prefix == defaultv4 {
return err if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
return fmt.Errorf("add blackhole: %w", err)
}
} }
if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
return fmt.Errorf("add route: %w", err)
}
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error {
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
return fmt.Errorf("remove unreachable route: %w", err)
}
}
if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil return nil
} }
func getRoutesFromTable() ([]netip.Prefix, error) { func getRoutesFromTable() ([]netip.Prefix, error) {
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4)
if err != nil { }
return nil, err
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: family,
} }
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil { if prefix != nil {
return nil, err _, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route.Dst = ipNet
} }
var prefixList []netip.Prefix
loop: if err := addNextHop(addr, intf, route); err != nil {
for _, m := range msgs { return fmt.Errorf("add gateway and device: %w", err)
switch m.Header.Type { }
case syscall.NLMSG_DONE:
break loop if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) {
case syscall.RTM_NEWROUTE: return fmt.Errorf("netlink add route: %w", err)
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) }
msg := m
attrs, err := syscall.ParseNetlinkRouteAttr(&msg) return nil
if err != nil { }
return nil, err
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
// tableID specifies the routing table to which the unreachable route will be added.
func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: ipFamily,
Dst: ipNet,
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) {
return fmt.Errorf("netlink add unreachable route: %w", err)
}
return nil
}
func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: ipFamily,
Dst: ipNet,
}
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
return fmt.Errorf("netlink remove unreachable route: %w", err)
}
return nil
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: family,
Dst: ipNet,
}
if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
return fmt.Errorf("netlink remove route: %w", err)
}
return nil
}
func flushRoutes(tableID, family int) error {
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return fmt.Errorf("list routes from table %d: %w", tableID, err)
}
var result *multierror.Error
for i := range routes {
route := routes[i]
// unreachable default routes don't come back with Dst set
if route.Gw == nil && route.Src == nil && route.Dst == nil {
if family == netlink.FAMILY_V4 {
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
} else {
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
} }
if rt.Family != syscall.AF_INET { }
continue loop if err := netlink.RouteDel(&routes[i]); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
}
}
return result.ErrorOrNil()
}
// getRoutes fetches routes from a specific routing table identified by tableID.
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
var prefixList []netip.Prefix
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
}
for _, route := range routes {
if route.Dst != nil {
addr, ok := netip.AddrFromSlice(route.Dst.IP)
if !ok {
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
} }
for _, attr := range attrs { ones, _ := route.Dst.Mask.Size()
if attr.Attr.Type == syscall.RTA_DST {
addr, ok := netip.AddrFromSlice(attr.Value) prefix := netip.PrefixFrom(addr, ones)
if !ok { if prefix.IsValid() {
continue prefixList = append(prefixList, prefix)
}
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
cidr, _ := mask.Size()
routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
prefixList = append(prefixList, routePrefix)
}
}
} }
} }
} }
return prefixList, nil return prefixList, nil
} }
func enableIPForwarding() error { func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath) bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil { if err != nil {
return err return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
} }
// check if it is already enabled // check if it is already enabled
@ -147,5 +303,142 @@ func enableIPForwarding() error {
return nil return nil
} }
return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec //nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
// and verifies if existing names start with "netbird_".
func entryExists(file *os.File, id int) (bool, error) {
if _, err := file.Seek(0, 0); err != nil {
return false, fmt.Errorf("seek rt_tables: %w", err)
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
var existingID int
var existingName string
if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil {
if existingID == id {
if existingName != NetbirdVPNTableName {
return true, ErrTableIDExists
}
return true, nil
}
}
}
if err := scanner.Err(); err != nil {
return false, fmt.Errorf("scan rt_tables: %w", err)
}
return false, nil
}
// addRoutingTableName adds human-readable names for custom routing tables.
func addRoutingTableName() error {
file, err := os.Open(rtTablesPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return fmt.Errorf("open rt_tables: %w", err)
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("Error closing rt_tables: %v", err)
}
}()
exists, err := entryExists(file, NetbirdVPNTableID)
if err != nil {
return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err)
}
if exists {
return nil
}
// Reopen the file in append mode to add new entries
if err := file.Close(); err != nil {
log.Errorf("Error closing rt_tables before appending: %v", err)
}
file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
return fmt.Errorf("open rt_tables for appending: %w", err)
}
if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil {
return fmt.Errorf("append entry to rt_tables: %w", err)
}
return nil
}
// addRule adds a routing rule to a specific routing table identified by tableID.
func addRule(params ruleParams) error {
rule := netlink.NewRule()
rule.Table = params.tableID
rule.Mark = params.fwmark
rule.Family = params.family
rule.Priority = params.priority
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil {
return fmt.Errorf("add routing rule: %w", err)
}
return nil
}
// removeRule removes a routing rule from a specific routing table identified by tableID.
func removeRule(params ruleParams) error {
rule := netlink.NewRule()
rule.Table = params.tableID
rule.Mark = params.fwmark
rule.Family = params.family
rule.Invert = params.invert
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
func removeAllRules(params ruleParams) error {
for {
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) {
break
}
return err
}
}
return nil
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr *string, intf *string, route *netlink.Route) error {
if addr != nil {
ip := net.ParseIP(*addr)
if ip == nil {
return fmt.Errorf("parsing address %s failed", *addr)
}
route.Gw = ip
}
if intf != nil {
link, err := netlink.LinkByName(*intf)
if err != nil {
return fmt.Errorf("set interface %s: %w", *intf, err)
}
route.LinkIndex = link.Attrs().Index
}
return nil
} }

View File

@ -0,0 +1,469 @@
//go:build !android
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"strings"
"syscall"
"testing"
"time"
"github.com/gopacket/gopacket"
"github.com/gopacket/gopacket/layers"
"github.com/gopacket/gopacket/pcap"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}
func TestRoutingWithTables(t *testing.T) {
testCases := []struct {
name string
destination string
captureInterface string
dialer *net.Dialer
packetExpectation PacketExpectation
}{
{
name: "To external host without fwmark via vpn",
destination: "192.0.2.1:53",
captureInterface: "wgtest0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with fwmark via physical interface",
destination: "192.0.2.1:53",
captureInterface: "dummyext0",
dialer: nbnet.NewDialer(),
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with fwmark via physical interface",
destination: "10.0.0.1:53",
captureInterface: "dummyint0",
dialer: nbnet.NewDialer(),
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
},
{
name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence
destination: "10.0.0.1:53",
captureInterface: "dummyint0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
},
{
name: "To unique vpn route with fwmark via physical interface",
destination: "172.16.0.1:53",
captureInterface: "dummyext0",
dialer: nbnet.NewDialer(),
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53),
},
{
name: "To unique vpn route without fwmark via vpn",
destination: "172.16.0.1:53",
captureInterface: "wgtest0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53),
},
{
name: "To more specific route without fwmark via vpn interface",
destination: "10.10.0.1:53",
captureInterface: "dummyint0",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53),
},
{
name: "To more specific route (local) without fwmark via physical interface",
destination: "127.0.10.1:53",
captureInterface: "lo",
dialer: &net.Dialer{},
packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
wgIface, _, _ := setupTestEnv(t)
// default route exists in main table and vpn table
err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// 10.0.0.0/8 route exists in main table and vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// 10.10.0.0/24 more specific route exists in vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// 127.0.10.0/24 more specific route exists in vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
// unique route in vpn table
err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
filter := createBPFFilter(tc.destination)
handle := startPacketCapture(t, tc.captureInterface, filter)
sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer)
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
packet, err := packetSource.NextPacket()
require.NoError(t, err)
verifyPacket(t, packet, tc.packetExpectation)
})
}
}
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
t.Helper()
ipLayer := packet.Layer(layers.LayerTypeIPv4)
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
ip, ok := ipLayer.(*layers.IPv4)
require.True(t, ok, "Failed to cast to IPv4 layer")
// Convert both source and destination IP addresses to 16-byte representation
expectedSrcIP := exp.SrcIP.To16()
actualSrcIP := ip.SrcIP.To16()
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
expectedDstIP := exp.DstIP.To16()
actualDstIP := ip.DstIP.To16()
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
if exp.UDP {
udpLayer := packet.Layer(layers.LayerTypeUDP)
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
udp, ok := udpLayer.(*layers.UDP)
require.True(t, ok, "Failed to cast to UDP layer")
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
}
if exp.TCP {
tcpLayer := packet.Layer(layers.LayerTypeTCP)
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
tcp, ok := tcpLayer.(*layers.TCP)
require.True(t, ok, "Failed to cast to TCP layer")
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
}
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy {
t.Helper()
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
err := netlink.LinkDel(dummy)
if err != nil && !errors.Is(err, syscall.EINVAL) {
t.Logf("Failed to delete dummy interface: %v", err)
}
err = netlink.LinkAdd(dummy)
require.NoError(t, err)
err = netlink.LinkSetUp(dummy)
require.NoError(t, err)
if ipAddressCIDR != "" {
addr, err := netlink.ParseAddr(ipAddressCIDR)
require.NoError(t, err)
err = netlink.AddrAdd(dummy, addr)
require.NoError(t, err)
}
return dummy
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
t.Helper()
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
require.NoError(t, err)
if dstIPNet.String() == "0.0.0.0/0" {
gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
// Handle existing routes with metric 0
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
if err == nil {
t.Cleanup(func() {
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0})
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
})
} else if !errors.Is(err, syscall.ESRCH) {
t.Logf("Failed to delete route: %v", err)
}
}
route := &netlink.Route{
Dst: dstIPNet,
Gw: gw,
LinkIndex: linkIndex,
}
err = netlink.RouteDel(route)
if err != nil && !errors.Is(err, syscall.ESRCH) {
t.Logf("Failed to delete route: %v", err)
}
err = netlink.RouteAdd(route)
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
}
// fetchOriginalGateway returns the original gateway IP address and the interface index.
func fetchOriginalGateway(family int) (net.IP, int, error) {
routes, err := netlink.RouteList(nil, family)
if err != nil {
return nil, 0, err
}
for _, route := range routes {
if route.Dst == nil {
return route.Gw, route.LinkIndex, nil
}
}
return nil, 0, fmt.Errorf("default route not found")
}
func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index)
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index)
t.Cleanup(func() {
err := netlink.LinkDel(defaultDummy)
assert.NoError(t, err)
err = netlink.LinkDel(otherDummy)
assert.NoError(t, err)
})
return defaultDummy.Name, otherDummy.Name
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet(nil)
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
require.NoError(t, err, "should create testing WireGuard interface")
t.Cleanup(func() {
wgInterface.Close()
})
return wgInterface
}
func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) {
t.Helper()
defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgIface.Close())
})
err := setupRouting()
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
return wgIface, defaultDummy, otherDummy
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()
inactive, err := pcap.NewInactiveHandle(intf)
require.NoError(t, err, "Failed to create inactive pcap handle")
defer inactive.CleanUp()
err = inactive.SetSnapLen(1600)
require.NoError(t, err, "Failed to set snap length on inactive handle")
err = inactive.SetTimeout(time.Second * 10)
require.NoError(t, err, "Failed to set timeout on inactive handle")
err = inactive.SetImmediateMode(true)
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
handle, err := inactive.Activate()
require.NoError(t, err, "Failed to activate pcap handle")
t.Cleanup(handle.Close)
err = handle.SetBPFFilter(filter)
require.NoError(t, err, "Failed to set BPF filter")
return handle
}
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) {
t.Helper()
if dialer == nil {
dialer = &net.Dialer{}
}
if sourcePort != 0 {
localUDPAddr := &net.UDPAddr{
IP: net.IPv4zero,
Port: sourcePort,
}
dialer.LocalAddr = localUDPAddr
}
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
conn, err := dialer.Dial("udp", destination)
require.NoError(t, err, "Failed to dial UDP")
defer conn.Close()
data, err := msg.Pack()
require.NoError(t, err, "Failed to pack DNS message")
_, err = conn.Write(data)
if err != nil {
if strings.Contains(err.Error(), "required key not available") {
t.Logf("Ignoring WireGuard key error: %v", err)
return
}
t.Fatalf("Failed to send DNS query: %v", err)
}
}
func createBPFFilter(destination string) string {
host, port, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
}
return "udp"
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}

View File

@ -1,11 +1,15 @@
//go:build !android && !ios //go:build !android
//nolint:unused
package routemanager package routemanager
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os/exec"
"runtime"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -13,41 +17,16 @@ import (
var errRouteNotFound = fmt.Errorf("route not found") var errRouteNotFound = fmt.Errorf("route not found")
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
ok, err := existsInRouteTable(prefix) defaultGateway, err := getExistingRIBRouteGateway(defaultv4)
if err != nil { if err != nil && !errors.Is(err, errRouteNotFound) {
return err return fmt.Errorf("get existing route gateway: %s", err)
}
if ok {
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return err
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, addr)
}
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil && err != errRouteNotFound {
return err
} }
addr := netip.MustParseAddr(defaultGateway.String()) addr := netip.MustParseAddr(defaultGateway.String())
if !prefix.Contains(addr) { if !prefix.Contains(addr) {
log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
return nil return nil
} }
@ -59,22 +38,93 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
} }
if ok { if ok {
log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil return nil
} }
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
if err != nil && err != errRouteNotFound { if err != nil && !errors.Is(err, errRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
} }
log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop.String()) return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "")
}
func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := genericAddRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return genericAddToRouteTable(prefix, addr, intf)
}
func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
return genericRemoveFromRouteTable(prefix, addr, intf)
}
func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error {
cmd := exec.Command("route", "add", prefix.String(), addr)
out, err := cmd.Output()
if err != nil {
return fmt.Errorf("add route: %w", err)
}
log.Debugf(string(out))
return nil
}
func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error {
args := []string{"delete", prefix.String()}
if runtime.GOOS == "darwin" {
args = append(args, addr)
}
cmd := exec.Command("route", args...)
out, err := cmd.Output()
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
log.Debugf(string(out))
return nil
}
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
r, err := netroute.New()
if err != nil {
return nil, fmt.Errorf("new netroute: %w", err)
}
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
if err != nil {
log.Errorf("Getting routes returned an error: %v", err)
return nil, errRouteNotFound
}
if gateway == nil {
return preferredSrc, nil
}
return gateway, nil
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) { func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable() routes, err := getRoutesFromTable()
if err != nil { if err != nil {
return false, err return false, fmt.Errorf("get routes from table: %w", err)
} }
for _, tableRoute := range routes { for _, tableRoute := range routes {
if tableRoute == prefix { if tableRoute == prefix {
@ -87,34 +137,12 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
func isSubRange(prefix netip.Prefix) (bool, error) { func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable() routes, err := getRoutesFromTable()
if err != nil { if err != nil {
return false, err return false, fmt.Errorf("get routes from table: %w", err)
} }
for _, tableRoute := range routes { for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil return true, nil
} }
} }
return false, nil return false, nil
} }
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return removeFromRouteTable(prefix, addr)
}
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
r, err := netroute.New()
if err != nil {
return nil, err
}
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
if err != nil {
log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound
}
if gateway == nil {
return preferredSrc, nil
}
return gateway, nil
}

View File

@ -8,17 +8,63 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"os/exec"
"runtime"
"strings" "strings"
"testing" "testing"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" {
outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String())
require.NoError(t, err, "getOutgoingInterfaceLinux should not return error")
if invert {
require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface")
} else {
require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface")
}
return
}
prefixGateway, err := getExistingRIBRouteGateway(prefix)
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
}
}
func getOutgoingInterfaceLinux(destination string) (string, error) {
cmd := exec.Command("ip", "route", "get", destination)
output, err := cmd.Output()
if err != nil {
return "", fmt.Errorf("executing ip route get: %w", err)
}
return parseOutgoingInterface(string(output)), nil
}
func parseOutgoingInterface(routeGetOutput string) string {
fields := strings.Fields(routeGetOutput)
for i, field := range fields {
if field == "dev" && i+1 < len(fields) {
return fields[i+1]
}
}
return ""
}
func TestAddRemoveRoutes(t *testing.T) { func TestAddRemoveRoutes(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@ -54,23 +100,26 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) require.NoError(t, setupRouting())
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
require.NoError(t, err, "addToRouteTableIfNoExists should not return err") require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
if testCase.shouldRouteToWireguard { if testCase.shouldRouteToWireguard {
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") assertWGOutInterface(t, testCase.prefix, wgInterface, false)
} else { } else {
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") assertWGOutInterface(t, testCase.prefix, wgInterface, true)
} }
exists, err := existsInRouteTable(testCase.prefix) exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err") require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard { if exists && testCase.shouldRouteToWireguard {
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "getExistingRIBRouteGateway should not return err") require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
@ -189,16 +238,21 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
require.NoError(t, setupRouting())
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
MockAddr := wgInterface.Address().IP.String() MockAddr := wgInterface.Address().IP.String()
// Prepare the environment // Prepare the environment
if testCase.preExistingPrefix.IsValid() { if testCase.preExistingPrefix.IsValid() {
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr) err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name())
require.NoError(t, err, "should not return err when adding pre-existing route") require.NoError(t, err, "should not return err when adding pre-existing route")
} }
// Add the route // Add the route
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr) err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name())
require.NoError(t, err, "should not return err when adding route") require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute { if testCase.shouldAddRoute {
@ -208,7 +262,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
require.True(t, ok, "route should exist") require.True(t, ok, "route should exist")
// remove route again if added // remove route again if added
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr) err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name())
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
} }
@ -217,72 +271,12 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
ok, err := existsInRouteTable(testCase.prefix) ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String()) t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
// Linux uses a separate routing table, so the route can exist in both tables.
// The main routing table takes precedence over the wireguard routing table.
if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" {
require.False(t, ok, "route should not exist") require.False(t, ok, "route should not exist")
} }
}) })
} }
} }
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is4() {
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}

View File

@ -1,41 +1,22 @@
//go:build !linux //go:build !linux || android
// +build !linux
package routemanager package routemanager
import ( import (
"net/netip"
"os/exec"
"runtime" "runtime"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func addToRouteTable(prefix netip.Prefix, addr string) error { func setupRouting() error {
cmd := exec.Command("route", "add", prefix.String(), addr)
out, err := cmd.Output()
if err != nil {
return err
}
log.Debugf(string(out))
return nil return nil
} }
func removeFromRouteTable(prefix netip.Prefix, addr string) error { func cleanupRouting() error {
args := []string{"delete", prefix.String()}
if runtime.GOOS == "darwin" {
args = append(args, addr)
}
cmd := exec.Command("route", args...)
out, err := cmd.Output()
if err != nil {
return err
}
log.Debugf(string(out))
return nil return nil
} }
func enableIPForwarding() error { func enableIPForwarding() error {
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil
} }

View File

@ -0,0 +1,80 @@
//go:build !linux || android
package routemanager
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
require.NoError(t, setupRouting())
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is4() {
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}

View File

@ -1,12 +1,13 @@
//go:build windows //go:build windows
// +build windows
package routemanager package routemanager
import ( import (
"fmt"
"net" "net"
"net/netip" "net/netip"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
) )
@ -21,17 +22,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
err := wmi.Query(query, &routes) err := wmi.Query(query, &routes)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("get routes: %w", err)
} }
var prefixList []netip.Prefix var prefixList []netip.Prefix
for _, route := range routes { for _, route := range routes {
addr, err := netip.ParseAddr(route.Destination) addr, err := netip.ParseAddr(route.Destination)
if err != nil { if err != nil {
log.Warnf("Unable to parse route destination %s: %v", route.Destination, err)
continue continue
} }
maskSlice := net.ParseIP(route.Mask).To4() maskSlice := net.ParseIP(route.Mask).To4()
if maskSlice == nil { if maskSlice == nil {
log.Warnf("Unable to parse route mask %s", route.Mask)
continue continue
} }
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
@ -44,3 +47,11 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
} }
return prefixList, nil return prefixList, nil
} }
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
}

View File

@ -0,0 +1,24 @@
package stdnet
import (
"net"
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
)
// Dial connects to the address on the named network.
func (n *Net) Dial(network, address string) (net.Conn, error) {
return nbnet.NewDialer().Dial(network, address)
}
// DialUDP connects to the address on the named UDP network.
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
return nbnet.DialUDP(network, laddr, raddr)
}
// DialTCP connects to the address on the named TCP network.
func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
return nbnet.DialTCP(network, laddr, raddr)
}

View File

@ -0,0 +1,20 @@
package stdnet
import (
"context"
"net"
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
)
// ListenPacket listens for incoming packets on the given network and address.
func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) {
return nbnet.NewListener().ListenPacket(context.Background(), network, address)
}
// ListenUDP acts like ListenPacket for UDP networks.
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
return nbnet.ListenUDP(network, locAddr)
}

View File

@ -1,8 +1,10 @@
package wgproxy package wgproxy
import ( import (
"context"
"fmt" "fmt"
"net"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@ -23,7 +25,7 @@ func (pl portLookup) searchFreePort() (int, error) {
} }
func (pl portLookup) tryToBind(port int) error { func (pl portLookup) tryToBind(port int) error {
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port))
if err != nil { if err != nil {
return err return err
} }

View File

@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
@ -66,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
} }
p.conn, err = net.ListenUDP("udp", &addr) p.conn, err = nbnet.ListenUDP("udp", &addr)
if err != nil { if err != nil {
cErr := p.Free() cErr := p.Free()
if cErr != nil { if cErr != nil {
@ -208,20 +209,41 @@ generatePort:
} }
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, err
}
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, err
} }
return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))) // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {

View File

@ -6,6 +6,8 @@ import (
"net" "net"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// WGUserSpaceProxy proxies // WGUserSpaceProxy proxies
@ -33,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = remoteConn p.remoteConn = remoteConn
var err error var err error
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err) log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err return nil, err

4
go.mod
View File

@ -47,8 +47,9 @@ require (
github.com/google/go-cmp v0.5.9 github.com/google/go-cmp v0.5.9
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.0 github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0 github.com/hashicorp/go-version v1.6.0
github.com/libp2p/go-netroute v0.2.0 github.com/libp2p/go-netroute v0.2.0
@ -123,7 +124,6 @@ require (
github.com/google/s2a-go v0.1.4 // indirect github.com/google/s2a-go v0.1.4 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.10.0 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect
github.com/gopacket/gopacket v1.1.1 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect

4
go.sum
View File

@ -291,8 +291,8 @@ github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f2
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=

View File

@ -23,6 +23,24 @@ func parseWGAddress(address string) (WGAddress, error) {
}, nil }, nil
} }
// Masked returns the WGAddress with the IP address part masked according to its network mask.
func (addr WGAddress) Masked() WGAddress {
ip := addr.IP.To4()
if ip == nil {
ip = addr.IP.To16()
}
maskedIP := make(net.IP, len(ip))
for i := range ip {
maskedIP[i] = ip[i] & addr.Network.Mask[i]
}
return WGAddress{
IP: maskedIP,
Network: addr.Network,
}
}
func (addr WGAddress) String() string { func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size() maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)

View File

@ -10,6 +10,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type wgKernelConfigurer struct { type wgKernelConfigurer struct {
@ -29,7 +31,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
if err != nil { if err != nil {
return err return err
} }
fwmark := 0 fwmark := nbnet.NetbirdFwmark
config := wgtypes.Config{ config := wgtypes.Config{
PrivateKey: &key, PrivateKey: &key,
ReplacePeers: true, ReplacePeers: true,

View File

@ -13,6 +13,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type wgUSPConfigurer struct { type wgUSPConfigurer struct {
@ -37,7 +39,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
if err != nil { if err != nil {
return err return err
} }
fwmark := 0 fwmark := getFwmark()
config := wgtypes.Config{ config := wgtypes.Config{
PrivateKey: &key, PrivateKey: &key,
ReplacePeers: true, ReplacePeers: true,
@ -345,3 +347,10 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
} }
return sb.String() return sb.String()
} }
func getFwmark() int {
if runtime.GOOS == "linux" {
return nbnet.NetbirdFwmark
}
return 0
}

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
) )
const ConnectTimeout = 10 * time.Second const ConnectTimeout = 10 * time.Second
@ -57,6 +58,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
mgmCtx, mgmCtx,
addr, addr,
transportOption, transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(), grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,

View File

@ -21,6 +21,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// ErrSharedSockStopped indicates that shared socket has been stopped // ErrSharedSockStopped indicates that shared socket has been stopped
@ -55,8 +57,7 @@ var writeSerializerOptions = gopacket.SerializeOptions{
} }
// Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines // Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines
func Listen(port int, filter BPFFilter) (net.PacketConn, error) { func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) {
var err error
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
rawSock := &SharedSocket{ rawSock := &SharedSocket{
ctx: ctx, ctx: ctx,
@ -65,37 +66,51 @@ func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
packetDemux: make(chan rcvdPacket), packetDemux: make(chan rcvdPacket),
} }
defer func() {
if err != nil {
if closeErr := rawSock.Close(); closeErr != nil {
log.Errorf("Failed to close raw socket: %v", closeErr)
}
}
}()
rawSock.router, err = netroute.New() rawSock.router, err = netroute.New()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create raw socket router: %v", err) return nil, fmt.Errorf("failed to create raw socket router: %w", err)
} }
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ipv4 raw socket: %v", err) return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
} }
rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if err = nbnet.SetSocketMark(rawSock.conn4); err != nil {
if err != nil { return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err)
log.Errorf("failed to create ipv6 raw socket: %v", err) }
var sockErr error
rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
if sockErr != nil {
log.Errorf("Failed to create ipv6 raw socket: %v", err)
} else {
if err = nbnet.SetSocketMark(rawSock.conn6); err != nil {
return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err)
}
} }
ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port))
if err != nil { if err != nil {
_ = rawSock.Close() return nil, fmt.Errorf("getBPFInstructions failed with: %w", err)
return nil, fmt.Errorf("getBPFInstructions failed with: %rawSock", err)
} }
err = rawSock.conn4.SetBPF(ipv4Instructions) err = rawSock.conn4.SetBPF(ipv4Instructions)
if err != nil { if err != nil {
_ = rawSock.Close() return nil, fmt.Errorf("socket4.SetBPF failed with: %w", err)
return nil, fmt.Errorf("socket4.SetBPF failed with: %rawSock", err)
} }
if rawSock.conn6 != nil { if rawSock.conn6 != nil {
err = rawSock.conn6.SetBPF(ipv6Instructions) err = rawSock.conn6.SetBPF(ipv6Instructions)
if err != nil { if err != nil {
_ = rawSock.Close() return nil, fmt.Errorf("socket6.SetBPF failed with: %w", err)
return nil, fmt.Errorf("socket6.SetBPF failed with: %rawSock", err)
} }
} }
@ -121,7 +136,7 @@ func (s *SharedSocket) updateRouter() {
case <-ticker.C: case <-ticker.C:
router, err := netroute.New() router, err := netroute.New()
if err != nil { if err != nil {
log.Errorf("failed to create and update packet router for stunListener: %s", err) log.Errorf("Failed to create and update packet router for stunListener: %s", err)
continue continue
} }
s.routerMux.Lock() s.routerMux.Lock()
@ -144,7 +159,7 @@ func (s *SharedSocket) LocalAddr() net.Addr {
func (s *SharedSocket) SetDeadline(t time.Time) error { func (s *SharedSocket) SetDeadline(t time.Time) error {
err := s.conn4.SetDeadline(t) err := s.conn4.SetDeadline(t)
if err != nil { if err != nil {
return fmt.Errorf("s.conn4.SetDeadline error: %s", err) return fmt.Errorf("s.conn4.SetDeadline error: %w", err)
} }
if s.conn6 == nil { if s.conn6 == nil {
return nil return nil
@ -152,7 +167,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error {
err = s.conn6.SetDeadline(t) err = s.conn6.SetDeadline(t)
if err != nil { if err != nil {
return fmt.Errorf("s.conn6.SetDeadline error: %s", err) return fmt.Errorf("s.conn6.SetDeadline error: %w", err)
} }
return nil return nil
} }
@ -161,7 +176,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error {
func (s *SharedSocket) SetReadDeadline(t time.Time) error { func (s *SharedSocket) SetReadDeadline(t time.Time) error {
err := s.conn4.SetReadDeadline(t) err := s.conn4.SetReadDeadline(t)
if err != nil { if err != nil {
return fmt.Errorf("s.conn4.SetReadDeadline error: %s", err) return fmt.Errorf("s.conn4.SetReadDeadline error: %w", err)
} }
if s.conn6 == nil { if s.conn6 == nil {
return nil return nil
@ -169,7 +184,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error {
err = s.conn6.SetReadDeadline(t) err = s.conn6.SetReadDeadline(t)
if err != nil { if err != nil {
return fmt.Errorf("s.conn6.SetReadDeadline error: %s", err) return fmt.Errorf("s.conn6.SetReadDeadline error: %w", err)
} }
return nil return nil
} }
@ -178,7 +193,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error {
func (s *SharedSocket) SetWriteDeadline(t time.Time) error { func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
err := s.conn4.SetWriteDeadline(t) err := s.conn4.SetWriteDeadline(t)
if err != nil { if err != nil {
return fmt.Errorf("s.conn4.SetWriteDeadline error: %s", err) return fmt.Errorf("s.conn4.SetWriteDeadline error: %w", err)
} }
if s.conn6 == nil { if s.conn6 == nil {
return nil return nil
@ -186,7 +201,7 @@ func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
err = s.conn6.SetWriteDeadline(t) err = s.conn6.SetWriteDeadline(t)
if err != nil { if err != nil {
return fmt.Errorf("s.conn6.SetWriteDeadline error: %s", err) return fmt.Errorf("s.conn6.SetWriteDeadline error: %w", err)
} }
return nil return nil
} }
@ -282,7 +297,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
_, _, src, err := s.router.Route(rUDPAddr.IP) _, _, src, err := s.router.Route(rUDPAddr.IP)
if err != nil { if err != nil {
return 0, fmt.Errorf("got an error while checking route, err: %s", err) return 0, fmt.Errorf("got an error while checking route, err: %w", err)
} }
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
@ -292,7 +307,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
} }
if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil { if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil {
return -1, fmt.Errorf("failed serialize rcvdPacket: %s", err) return -1, fmt.Errorf("failed serialize rcvdPacket: %w", err)
} }
bufser := buffer.Bytes() bufser := buffer.Bytes()

View File

@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
) )
// ConnStateNotifier is a wrapper interface of the status recorder // ConnStateNotifier is a wrapper interface of the status recorder
@ -76,6 +77,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
sigCtx, sigCtx,
addr, addr,
transportOption, transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(), grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,

View File

@ -0,0 +1,9 @@
//go:build !linux || android
package grpc
import "google.golang.org/grpc"
func WithCustomDialer() grpc.DialOption {
return grpc.EmptyDialOption{}
}

18
util/grpc/dialer_linux.go Normal file
View File

@ -0,0 +1,18 @@
//go:build !android
package grpc
import (
"context"
"net"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return nbnet.NewDialer().DialContext(ctx, "tcp", addr)
})
}

View File

@ -0,0 +1,19 @@
//go:build !linux || android
package net
import (
"net"
)
func NewDialer() *net.Dialer {
return &net.Dialer{}
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
}

60
util/net/dialer_linux.go Normal file
View File

@ -0,0 +1,60 @@
//go:build !android
package net
import (
"context"
"fmt"
"net"
"syscall"
log "github.com/sirupsen/logrus"
)
func NewDialer() *net.Dialer {
return &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return SetRawSocketMark(c)
},
}
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type")
}
return tcpConn, nil
}

View File

@ -0,0 +1,13 @@
//go:build !linux || android
package net
import "net"
func NewListener() *net.ListenConfig {
return &net.ListenConfig{}
}
func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, locAddr)
}

View File

@ -0,0 +1,30 @@
//go:build !android
package net
import (
"context"
"fmt"
"net"
"syscall"
)
func NewListener() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return SetRawSocketMark(c)
},
}
}
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err)
}
udpConn, ok := pc.(*net.UDPConn)
if !ok {
return nil, fmt.Errorf("packetConn is not a *net.UDPConn")
}
return udpConn, nil
}

6
util/net/net.go Normal file
View File

@ -0,0 +1,6 @@
package net
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
)

35
util/net/net_linux.go Normal file
View File

@ -0,0 +1,35 @@
//go:build !android
package net
import (
"fmt"
"syscall"
)
// SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error {
sysconn, err := conn.SyscallConn()
if err != nil {
return fmt.Errorf("get raw conn: %w", err)
}
return SetRawSocketMark(sysconn)
}
func SetRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
})
if err != nil {
return fmt.Errorf("control: %w", err)
}
if setErr != nil {
return fmt.Errorf("set SO_MARK: %w", setErr)
}
return nil
}