mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
Support client default routes for Linux (#1667)
All routes are now installed in a custom netbird routing table. Management and wireguard traffic is now marked with a custom fwmark. When the mark is present the traffic is routed via the main routing table, bypassing the VPN. When the mark is absent the traffic is routed via the netbird routing table, if: - there's no match in the main routing table - it would match the default route in the routing table IPv6 traffic is blocked when a default route IPv4 route is configured to avoid leakage.
This commit is contained in:
parent
846871913d
commit
2475473227
16
.github/workflows/golang-test-linux.yml
vendored
16
.github/workflows/golang-test-linux.yml
vendored
@ -14,8 +14,8 @@ jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['386','amd64']
|
||||
store: ['jsonfile', 'sqlite']
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'jsonfile', 'sqlite' ]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
@ -36,7 +36,11 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@ -67,7 +71,7 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
@ -82,7 +86,7 @@ jobs:
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
@ -109,7 +113,7 @@ jobs:
|
||||
|
||||
- name: Run Engine tests in docker with file store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
||||
- name: Run Engine tests in docker with sqlite store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
|
@ -230,8 +230,8 @@ func (e *Engine) Start() error {
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
return err
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||
return fmt.Errorf("new wg interface: %w", err)
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
@ -244,29 +244,33 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.rpManager, err = rosenpass.NewManager(e.config.PreSharedKey, e.config.WgIfaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||
}
|
||||
err := e.rpManager.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("create dns server: %w", err)
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
if err := e.routeManager.Init(); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("init route manager: %w", err)
|
||||
}
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
err = e.wgInterfaceCreate()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
||||
@ -278,7 +282,7 @@ func (e *Engine) Start() error {
|
||||
err = e.routeManager.EnableServerRouter(e.firewall)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,7 +290,7 @@ func (e *Engine) Start() error {
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
@ -296,7 +300,7 @@ func (e *Engine) Start() error {
|
||||
err = e.dnsServer.Initialize()
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
e.receiveSignalEvents()
|
||||
|
@ -10,6 +10,9 @@ import (
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/turn/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ProbeResult holds the info about the result of a relay probe request
|
||||
@ -27,7 +30,15 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := stun.DialURI(uri, &stun.DialConfig{})
|
||||
net, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := stun.DialURI(uri, &stun.DialConfig{
|
||||
Net: net,
|
||||
})
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("dial: %w", err)
|
||||
return
|
||||
@ -85,14 +96,13 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
switch uri.Proto {
|
||||
case stun.ProtoTypeUDP:
|
||||
var err error
|
||||
conn, err = net.ListenPacket("udp", "")
|
||||
conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("listen: %w", err)
|
||||
return
|
||||
}
|
||||
case stun.ProtoTypeTCP:
|
||||
dialer := net.Dialer{}
|
||||
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
|
||||
tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("dial: %w", err)
|
||||
return
|
||||
@ -109,12 +119,18 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
||||
}
|
||||
}()
|
||||
|
||||
net, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
}
|
||||
cfg := &turn.ClientConfig{
|
||||
STUNServerAddr: turnServerAddr,
|
||||
TURNServerAddr: turnServerAddr,
|
||||
Conn: conn,
|
||||
Username: uri.Username,
|
||||
Password: uri.Password,
|
||||
Net: net,
|
||||
}
|
||||
client, err := turn.NewClient(cfg)
|
||||
if err != nil {
|
||||
|
@ -41,6 +41,7 @@ type clientNetwork struct {
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
@ -72,6 +73,18 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
||||
return routePeerStatuses
|
||||
}
|
||||
|
||||
// getBestRouteFromStatuses determines the most optimal route from the available routes
|
||||
// within a clientNetwork, taking into account peer connection status, route metrics, and
|
||||
// preference for non-relayed and direct connections.
|
||||
//
|
||||
// It follows these prioritization rules:
|
||||
// * Connected peers: Only routes with connected peers are considered.
|
||||
// * Metric: Routes with lower metrics (better) are prioritized.
|
||||
// * Non-relayed: Routes without relays are preferred.
|
||||
// * Direct connections: Routes with direct peer connections are favored.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
//
|
||||
// It returns the ID of the selected optimal route.
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||
chosen := ""
|
||||
chosenScore := 0
|
||||
@ -158,7 +171,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get peer state: %v", err)
|
||||
}
|
||||
|
||||
delete(state.Routes, c.network.String())
|
||||
@ -172,7 +185,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
|
||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
|
||||
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
return nil
|
||||
@ -180,30 +193,26 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.chosenRoute != nil {
|
||||
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
|
||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||
}
|
||||
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
||||
c.network, err)
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
|
||||
var err error
|
||||
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
|
||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
|
||||
// If no route is chosen, remove the route from the peer and system
|
||||
if chosen == "" {
|
||||
err = c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
return err
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
return fmt.Errorf("remove route from peer and system: %v", err)
|
||||
}
|
||||
|
||||
c.chosenRoute = nil
|
||||
@ -211,6 +220,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the chosen route is the same as the current route, do nothing
|
||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||
return nil
|
||||
@ -218,13 +228,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
|
||||
if c.chosenRoute != nil {
|
||||
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
// If a previous route exists, remove it from the peer
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route from peer: %v", err)
|
||||
}
|
||||
} else {
|
||||
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
|
||||
if err != nil {
|
||||
// otherwise add the route to the system
|
||||
if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||
}
|
||||
@ -245,8 +255,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
}
|
||||
|
||||
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
|
||||
if err != nil {
|
||||
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
|
||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
@ -287,21 +296,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
log.Debugf("stopping watcher for network %s", c.network)
|
||||
err := c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
|
||||
}
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
|
||||
}
|
||||
case update := <-c.routeUpdate:
|
||||
if update.updateSerial < c.updateSerial {
|
||||
log.Warnf("received a routes update with smaller serial number, ignoring it")
|
||||
log.Warnf("Received a routes update with smaller serial number, ignoring it")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("received a new client network route update for %s", c.network)
|
||||
log.Debugf("Received a new client network route update for %s", c.network)
|
||||
|
||||
c.handleUpdate(update)
|
||||
|
||||
@ -309,7 +318,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
|
@ -2,6 +2,8 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
@ -15,8 +17,14 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
|
||||
// nolint:unused
|
||||
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() error
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
@ -56,6 +64,19 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
return dm
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() error {
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
if err := setupRouting(); err != nil {
|
||||
return fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
log.Info("Routing setup complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
var err error
|
||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
@ -71,9 +92,15 @@ func (m *DefaultManager) Stop() {
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
}
|
||||
m.ctx = nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
@ -91,7 +118,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("update routes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,11 +183,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
||||
for _, newRoute := range newRoutes {
|
||||
networkID := route.GetHAUniqueID(newRoute)
|
||||
if !ownNetworkIDs[networkID] {
|
||||
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||
// we skip this route management
|
||||
if newRoute.Network.Bits() < minRangeBits {
|
||||
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
|
||||
version.NetbirdVersion(), newRoute.Network)
|
||||
if !isPrefixSupported(newRoute.Network) {
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||
@ -178,3 +201,18 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||
if runtime.GOOS == "linux" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||
// we skip this prefix management
|
||||
if prefix.Bits() < minRangeBits {
|
||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||
version.NetbirdVersion(), prefix)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -28,13 +28,14 @@ const remotePeerKey2 = "remote1"
|
||||
|
||||
func TestManagerUpdateRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
removeSrvRouter bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
name string
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
removeSrvRouter bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
clientNetworkWatchersExpectedLinux int
|
||||
}{
|
||||
{
|
||||
name: "Should create 2 client networks",
|
||||
@ -200,8 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
clientNetworkWatchersExpectedLinux: 1,
|
||||
},
|
||||
{
|
||||
name: "Remove 1 Client Route",
|
||||
@ -415,6 +417,8 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||
err = routeManager.Init()
|
||||
require.NoError(t, err, "should init route manager")
|
||||
defer routeManager.Stop()
|
||||
|
||||
if testCase.removeSrvRouter {
|
||||
@ -429,7 +433,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 {
|
||||
expectedWatchers = testCase.clientNetworkWatchersExpectedLinux
|
||||
}
|
||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||
|
||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
||||
|
@ -16,6 +16,10 @@ type MockManager struct {
|
||||
StopFunc func()
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||
func (m *MockManager) InitialRouteRange() []string {
|
||||
return nil
|
||||
|
@ -4,6 +4,7 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@ -48,7 +49,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
||||
oldRoute := m.routes[routeID]
|
||||
err := m.removeFromServerNetwork(oldRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
|
||||
oldRoute.ID, oldRoute.Network, err)
|
||||
}
|
||||
delete(m.routes, routeID)
|
||||
@ -62,7 +63,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
||||
|
||||
err := m.addToServerNetwork(newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||
continue
|
||||
}
|
||||
m.routes[id] = newRoute
|
||||
@ -81,15 +82,22 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
log.Infof("Not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove routing rules: %w", err)
|
||||
}
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
@ -103,15 +111,22 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
log.Infof("Not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.InsertRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert routing rules: %w", err)
|
||||
}
|
||||
|
||||
m.routes[route.ID] = route
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
@ -129,9 +144,15 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, r := range m.routes {
|
||||
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
|
||||
if err != nil {
|
||||
log.Warnf("failed to remove clean up route: %s", r.ID)
|
||||
log.Errorf("Failed to convert route to router pair: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||
}
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
@ -139,13 +160,15 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
|
||||
parsed := netip.MustParsePrefix(source).Masked()
|
||||
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
||||
parsed, err := netip.ParsePrefix(source)
|
||||
if err != nil {
|
||||
return firewall.RouterPair{}, err
|
||||
}
|
||||
return firewall.RouterPair{
|
||||
ID: route.ID,
|
||||
Source: parsed.String(),
|
||||
Destination: route.Network.Masked().String(),
|
||||
Masquerade: route.Masquerade,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -4,10 +4,10 @@ import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -1,5 +1,4 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
// +build darwin dragonfly freebsd netbsd openbsd
|
||||
|
||||
package routemanager
|
||||
|
||||
|
13
client/internal/routemanager/systemops_bsd_nonios.go
Normal file
13
client/internal/routemanager/systemops_bsd_nonios.go
Normal file
@ -0,0 +1,13 @@
|
||||
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import "net/netip"
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
|
||||
}
|
@ -1,15 +1,13 @@
|
||||
//go:build ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -3,142 +3,298 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
|
||||
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
|
||||
type routeInfoInMemory struct {
|
||||
Family byte
|
||||
DstLen byte
|
||||
SrcLen byte
|
||||
TOS byte
|
||||
const (
|
||||
// NetbirdVPNTableID is the ID of the custom routing table used by Netbird.
|
||||
NetbirdVPNTableID = 0x1BD0
|
||||
// NetbirdVPNTableName is the name of the custom routing table used by Netbird.
|
||||
NetbirdVPNTableName = "netbird"
|
||||
|
||||
Table byte
|
||||
Protocol byte
|
||||
Scope byte
|
||||
Type byte
|
||||
// rtTablesPath is the path to the file containing the routing table names.
|
||||
rtTablesPath = "/etc/iproute2/rt_tables"
|
||||
|
||||
Flags uint32
|
||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||
)
|
||||
|
||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||
|
||||
type ruleParams struct {
|
||||
fwmark int
|
||||
tableID int
|
||||
family int
|
||||
priority int
|
||||
invert bool
|
||||
suppressPrefix int
|
||||
description string
|
||||
}
|
||||
|
||||
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||
func getSetupRules() []ruleParams {
|
||||
return []ruleParams{
|
||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"},
|
||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"},
|
||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"},
|
||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"},
|
||||
}
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return err
|
||||
// setupRouting establishes the routing configuration for the VPN, including essential rules
|
||||
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
|
||||
//
|
||||
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
|
||||
// potential routes received and configured for the VPN. This rule is skipped for the default route and routes
|
||||
// that are not in the main table.
|
||||
//
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
//
|
||||
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
||||
func setupRouting() (err error) {
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
addrMask := "/32"
|
||||
if prefix.Addr().Unmap().Is6() {
|
||||
addrMask = "/128"
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ip, _, err := net.ParseCIDR(addr + addrMask)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Dst: ipNet,
|
||||
Gw: ip,
|
||||
}
|
||||
|
||||
err = netlink.RouteAdd(route)
|
||||
if err != nil {
|
||||
return err
|
||||
rules := getSetupRules()
|
||||
for _, rule := range rules {
|
||||
if err := addRule(rule); err != nil {
|
||||
return fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return err
|
||||
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func cleanupRouting() error {
|
||||
var result *multierror.Error
|
||||
|
||||
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err))
|
||||
}
|
||||
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err))
|
||||
}
|
||||
|
||||
addrMask := "/32"
|
||||
if prefix.Addr().Unmap().Is6() {
|
||||
addrMask = "/128"
|
||||
rules := getSetupRules()
|
||||
for _, rule := range rules {
|
||||
if err := removeAllRules(rule); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
||||
}
|
||||
}
|
||||
|
||||
ip, _, err := net.ParseCIDR(addr + addrMask)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Dst: ipNet,
|
||||
Gw: ip,
|
||||
}
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error {
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 2
|
||||
|
||||
err = netlink.RouteDel(route)
|
||||
if err != nil {
|
||||
return err
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
||||
return fmt.Errorf("add blackhole: %w", err)
|
||||
}
|
||||
}
|
||||
if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error {
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
||||
return fmt.Errorf("remove unreachable route: %w", err)
|
||||
}
|
||||
}
|
||||
if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4)
|
||||
}
|
||||
|
||||
// addRoute adds a route to a specific routing table identified by tableID.
|
||||
func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
Family: family,
|
||||
}
|
||||
msgs, err := syscall.ParseNetlinkMessage(tab)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
if prefix != nil {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
}
|
||||
route.Dst = ipNet
|
||||
}
|
||||
var prefixList []netip.Prefix
|
||||
loop:
|
||||
for _, m := range msgs {
|
||||
switch m.Header.Type {
|
||||
case syscall.NLMSG_DONE:
|
||||
break loop
|
||||
case syscall.RTM_NEWROUTE:
|
||||
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
|
||||
msg := m
|
||||
attrs, err := syscall.ParseNetlinkRouteAttr(&msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
return fmt.Errorf("netlink add route: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
||||
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
||||
// tableID specifies the routing table to which the unreachable route will be added.
|
||||
func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Type: syscall.RTN_UNREACHABLE,
|
||||
Table: tableID,
|
||||
Family: ipFamily,
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
return fmt.Errorf("netlink add unreachable route: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Type: syscall.RTN_UNREACHABLE,
|
||||
Table: tableID,
|
||||
Family: ipFamily,
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
|
||||
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||
func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
Family: family,
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
|
||||
return fmt.Errorf("netlink remove route: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func flushRoutes(tableID, family int) error {
|
||||
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list routes from table %d: %w", tableID, err)
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for i := range routes {
|
||||
route := routes[i]
|
||||
// unreachable default routes don't come back with Dst set
|
||||
if route.Gw == nil && route.Src == nil && route.Dst == nil {
|
||||
if family == netlink.FAMILY_V4 {
|
||||
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
|
||||
} else {
|
||||
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
|
||||
}
|
||||
if rt.Family != syscall.AF_INET {
|
||||
continue loop
|
||||
}
|
||||
if err := netlink.RouteDel(&routes[i]); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// getRoutes fetches routes from a specific routing table identified by tableID.
|
||||
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
var prefixList []netip.Prefix
|
||||
|
||||
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Dst != nil {
|
||||
addr, ok := netip.AddrFromSlice(route.Dst.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
|
||||
}
|
||||
|
||||
for _, attr := range attrs {
|
||||
if attr.Attr.Type == syscall.RTA_DST {
|
||||
addr, ok := netip.AddrFromSlice(attr.Value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
|
||||
cidr, _ := mask.Size()
|
||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
||||
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
|
||||
prefixList = append(prefixList, routePrefix)
|
||||
}
|
||||
}
|
||||
ones, _ := route.Dst.Mask.Size()
|
||||
|
||||
prefix := netip.PrefixFrom(addr, ones)
|
||||
if prefix.IsValid() {
|
||||
prefixList = append(prefixList, prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
|
||||
}
|
||||
|
||||
// check if it is already enabled
|
||||
@ -147,5 +303,142 @@ func enableIPForwarding() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
|
||||
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||
// and verifies if existing names start with "netbird_".
|
||||
func entryExists(file *os.File, id int) (bool, error) {
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return false, fmt.Errorf("seek rt_tables: %w", err)
|
||||
}
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
var existingID int
|
||||
var existingName string
|
||||
if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil {
|
||||
if existingID == id {
|
||||
if existingName != NetbirdVPNTableName {
|
||||
return true, ErrTableIDExists
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return false, fmt.Errorf("scan rt_tables: %w", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addRoutingTableName adds human-readable names for custom routing tables.
|
||||
func addRoutingTableName() error {
|
||||
file, err := os.Open(rtTablesPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("open rt_tables: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("Error closing rt_tables: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
exists, err := entryExists(file, NetbirdVPNTableID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err)
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reopen the file in append mode to add new entries
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("Error closing rt_tables before appending: %v", err)
|
||||
}
|
||||
file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open rt_tables for appending: %w", err)
|
||||
}
|
||||
|
||||
if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil {
|
||||
return fmt.Errorf("append entry to rt_tables: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRule adds a routing rule to a specific routing table identified by tableID.
|
||||
func addRule(params ruleParams) error {
|
||||
rule := netlink.NewRule()
|
||||
rule.Table = params.tableID
|
||||
rule.Mark = params.fwmark
|
||||
rule.Family = params.family
|
||||
rule.Priority = params.priority
|
||||
rule.Invert = params.invert
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleAdd(rule); err != nil {
|
||||
return fmt.Errorf("add routing rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRule removes a routing rule from a specific routing table identified by tableID.
|
||||
func removeRule(params ruleParams) error {
|
||||
rule := netlink.NewRule()
|
||||
rule.Table = params.tableID
|
||||
rule.Mark = params.fwmark
|
||||
rule.Family = params.family
|
||||
rule.Invert = params.invert
|
||||
rule.Priority = params.priority
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleDel(rule); err != nil {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeAllRules(params ruleParams) error {
|
||||
for {
|
||||
if err := removeRule(params); err != nil {
|
||||
if errors.Is(err, syscall.ENOENT) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNextHop adds the gateway and device to the route.
|
||||
func addNextHop(addr *string, intf *string, route *netlink.Route) error {
|
||||
if addr != nil {
|
||||
ip := net.ParseIP(*addr)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parsing address %s failed", *addr)
|
||||
}
|
||||
|
||||
route.Gw = ip
|
||||
}
|
||||
|
||||
if intf != nil {
|
||||
link, err := netlink.LinkByName(*intf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set interface %s: %w", *intf, err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
469
client/internal/routemanager/systemops_linux_test.go
Normal file
469
client/internal/routemanager/systemops_linux_test.go
Normal file
@ -0,0 +1,469 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gopacket/gopacket"
|
||||
"github.com/gopacket/gopacket/layers"
|
||||
"github.com/gopacket/gopacket/pcap"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type PacketExpectation struct {
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
SrcPort int
|
||||
DstPort int
|
||||
UDP bool
|
||||
TCP bool
|
||||
}
|
||||
|
||||
func TestEntryExists(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
||||
|
||||
content := []string{
|
||||
"1000 reserved",
|
||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
||||
"9999 other_table",
|
||||
}
|
||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
||||
|
||||
file, err := os.Open(tempFilePath)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, file.Close())
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
shouldExist bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "ExistsWithNetbirdPrefix",
|
||||
id: 7120,
|
||||
shouldExist: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ExistsWithDifferentName",
|
||||
id: 1000,
|
||||
shouldExist: true,
|
||||
err: ErrTableIDExists,
|
||||
},
|
||||
{
|
||||
name: "DoesNotExist",
|
||||
id: 1234,
|
||||
shouldExist: false,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exists, err := entryExists(file, tc.id)
|
||||
if tc.err != nil {
|
||||
assert.ErrorIs(t, err, tc.err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.shouldExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingWithTables(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
destination string
|
||||
captureInterface string
|
||||
dialer *net.Dialer
|
||||
packetExpectation PacketExpectation
|
||||
}{
|
||||
{
|
||||
name: "To external host without fwmark via vpn",
|
||||
destination: "192.0.2.1:53",
|
||||
captureInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To external host with fwmark via physical interface",
|
||||
destination: "192.0.2.1:53",
|
||||
captureInterface: "dummyext0",
|
||||
dialer: nbnet.NewDialer(),
|
||||
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To duplicate internal route with fwmark via physical interface",
|
||||
destination: "10.0.0.1:53",
|
||||
captureInterface: "dummyint0",
|
||||
dialer: nbnet.NewDialer(),
|
||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence
|
||||
destination: "10.0.0.1:53",
|
||||
captureInterface: "dummyint0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with fwmark via physical interface",
|
||||
destination: "172.16.0.1:53",
|
||||
captureInterface: "dummyext0",
|
||||
dialer: nbnet.NewDialer(),
|
||||
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53),
|
||||
},
|
||||
{
|
||||
name: "To unique vpn route without fwmark via vpn",
|
||||
destination: "172.16.0.1:53",
|
||||
captureInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To more specific route without fwmark via vpn interface",
|
||||
destination: "10.10.0.1:53",
|
||||
captureInterface: "dummyint0",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53),
|
||||
},
|
||||
|
||||
{
|
||||
name: "To more specific route (local) without fwmark via physical interface",
|
||||
destination: "127.0.10.1:53",
|
||||
captureInterface: "lo",
|
||||
dialer: &net.Dialer{},
|
||||
packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
wgIface, _, _ := setupTestEnv(t)
|
||||
|
||||
// default route exists in main table and vpn table
|
||||
err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// 10.0.0.0/8 route exists in main table and vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
// unique route in vpn table
|
||||
err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
filter := createBPFFilter(tc.destination)
|
||||
handle := startPacketCapture(t, tc.captureInterface, filter)
|
||||
|
||||
sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer)
|
||||
|
||||
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
||||
packet, err := packetSource.NextPacket()
|
||||
require.NoError(t, err)
|
||||
|
||||
verifyPacket(t, packet, tc.packetExpectation)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
||||
t.Helper()
|
||||
|
||||
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
||||
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
||||
|
||||
ip, ok := ipLayer.(*layers.IPv4)
|
||||
require.True(t, ok, "Failed to cast to IPv4 layer")
|
||||
|
||||
// Convert both source and destination IP addresses to 16-byte representation
|
||||
expectedSrcIP := exp.SrcIP.To16()
|
||||
actualSrcIP := ip.SrcIP.To16()
|
||||
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
||||
|
||||
expectedDstIP := exp.DstIP.To16()
|
||||
actualDstIP := ip.DstIP.To16()
|
||||
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
||||
|
||||
if exp.UDP {
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
||||
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
require.True(t, ok, "Failed to cast to UDP layer")
|
||||
|
||||
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
||||
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
||||
}
|
||||
|
||||
if exp.TCP {
|
||||
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
||||
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
||||
|
||||
tcp, ok := tcpLayer.(*layers.TCP)
|
||||
require.True(t, ok, "Failed to cast to TCP layer")
|
||||
|
||||
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
||||
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy {
|
||||
t.Helper()
|
||||
|
||||
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
|
||||
err := netlink.LinkDel(dummy)
|
||||
if err != nil && !errors.Is(err, syscall.EINVAL) {
|
||||
t.Logf("Failed to delete dummy interface: %v", err)
|
||||
}
|
||||
|
||||
err = netlink.LinkAdd(dummy)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = netlink.LinkSetUp(dummy)
|
||||
require.NoError(t, err)
|
||||
|
||||
if ipAddressCIDR != "" {
|
||||
addr, err := netlink.ParseAddr(ipAddressCIDR)
|
||||
require.NoError(t, err)
|
||||
err = netlink.AddrAdd(dummy, addr)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return dummy
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
|
||||
t.Helper()
|
||||
|
||||
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
|
||||
require.NoError(t, err)
|
||||
|
||||
if dstIPNet.String() == "0.0.0.0/0" {
|
||||
gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4)
|
||||
if err != nil {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
// Handle existing routes with metric 0
|
||||
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
||||
if err == nil {
|
||||
t.Cleanup(func() {
|
||||
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0})
|
||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
t.Fatalf("Failed to add route: %v", err)
|
||||
}
|
||||
})
|
||||
} else if !errors.Is(err, syscall.ESRCH) {
|
||||
t.Logf("Failed to delete route: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
route := &netlink.Route{
|
||||
Dst: dstIPNet,
|
||||
Gw: gw,
|
||||
LinkIndex: linkIndex,
|
||||
}
|
||||
err = netlink.RouteDel(route)
|
||||
if err != nil && !errors.Is(err, syscall.ESRCH) {
|
||||
t.Logf("Failed to delete route: %v", err)
|
||||
}
|
||||
|
||||
err = netlink.RouteAdd(route)
|
||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
t.Fatalf("Failed to add route: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchOriginalGateway returns the original gateway IP address and the interface index.
|
||||
func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||
routes, err := netlink.RouteList(nil, family)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Dst == nil {
|
||||
return route.Gw, route.LinkIndex, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("default route not found")
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := netlink.LinkDel(defaultDummy)
|
||||
assert.NoError(t, err)
|
||||
err = netlink.LinkDel(otherDummy)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
return defaultDummy.Name, otherDummy.Name
|
||||
}
|
||||
|
||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
||||
t.Helper()
|
||||
|
||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
t.Cleanup(func() {
|
||||
wgInterface.Close()
|
||||
})
|
||||
|
||||
return wgInterface
|
||||
}
|
||||
|
||||
func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t)
|
||||
|
||||
wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, wgIface.Close())
|
||||
})
|
||||
|
||||
err := setupRouting()
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
return wgIface, defaultDummy, otherDummy
|
||||
}
|
||||
|
||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
||||
t.Helper()
|
||||
|
||||
inactive, err := pcap.NewInactiveHandle(intf)
|
||||
require.NoError(t, err, "Failed to create inactive pcap handle")
|
||||
defer inactive.CleanUp()
|
||||
|
||||
err = inactive.SetSnapLen(1600)
|
||||
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
||||
|
||||
err = inactive.SetTimeout(time.Second * 10)
|
||||
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
||||
|
||||
err = inactive.SetImmediateMode(true)
|
||||
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
||||
|
||||
handle, err := inactive.Activate()
|
||||
require.NoError(t, err, "Failed to activate pcap handle")
|
||||
t.Cleanup(handle.Close)
|
||||
|
||||
err = handle.SetBPFFilter(filter)
|
||||
require.NoError(t, err, "Failed to set BPF filter")
|
||||
|
||||
return handle
|
||||
}
|
||||
|
||||
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) {
|
||||
t.Helper()
|
||||
|
||||
if dialer == nil {
|
||||
dialer = &net.Dialer{}
|
||||
}
|
||||
|
||||
if sourcePort != 0 {
|
||||
localUDPAddr := &net.UDPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: sourcePort,
|
||||
}
|
||||
dialer.LocalAddr = localUDPAddr
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.Id = dns.Id()
|
||||
msg.RecursionDesired = true
|
||||
msg.Question = []dns.Question{
|
||||
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("udp", destination)
|
||||
require.NoError(t, err, "Failed to dial UDP")
|
||||
defer conn.Close()
|
||||
|
||||
data, err := msg.Pack()
|
||||
require.NoError(t, err, "Failed to pack DNS message")
|
||||
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "required key not available") {
|
||||
t.Logf("Ignoring WireGuard key error: %v", err)
|
||||
return
|
||||
}
|
||||
t.Fatalf("Failed to send DNS query: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createBPFFilter(destination string) string {
|
||||
host, port, err := net.SplitHostPort(destination)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
|
||||
}
|
||||
return "udp"
|
||||
}
|
||||
|
||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
||||
return PacketExpectation{
|
||||
SrcIP: net.ParseIP(srcIP),
|
||||
DstIP: net.ParseIP(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
UDP: true,
|
||||
}
|
||||
}
|
@ -1,11 +1,15 @@
|
||||
//go:build !android && !ios
|
||||
//go:build !android
|
||||
|
||||
//nolint:unused
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -13,41 +17,16 @@ import (
|
||||
|
||||
var errRouteNotFound = fmt.Errorf("route not found")
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok {
|
||||
err := addRouteForCurrentDefaultGateway(prefix)
|
||||
if err != nil {
|
||||
log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return addToRouteTable(prefix, addr)
|
||||
}
|
||||
|
||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
if err != nil && err != errRouteNotFound {
|
||||
return err
|
||||
func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
defaultGateway, err := getExistingRIBRouteGateway(defaultv4)
|
||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddr(defaultGateway.String())
|
||||
|
||||
if !prefix.Contains(addr) {
|
||||
log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -59,22 +38,93 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
|
||||
if err != nil && err != errRouteNotFound {
|
||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return addToRouteTable(gatewayPrefix, gatewayHop.String())
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "")
|
||||
}
|
||||
|
||||
func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
err := genericAddRouteForCurrentDefaultGateway(prefix)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return genericAddToRouteTable(prefix, addr, intf)
|
||||
}
|
||||
|
||||
func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericRemoveFromRouteTable(prefix, addr, intf)
|
||||
}
|
||||
|
||||
func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error {
|
||||
cmd := exec.Command("route", "add", prefix.String(), addr)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if runtime.GOOS == "darwin" {
|
||||
args = append(args, addr)
|
||||
}
|
||||
cmd := exec.Command("route", args...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
|
||||
if err != nil {
|
||||
log.Errorf("Getting routes returned an error: %v", err)
|
||||
return nil, errRouteNotFound
|
||||
}
|
||||
|
||||
if gateway == nil {
|
||||
return preferredSrc, nil
|
||||
}
|
||||
|
||||
return gateway, nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
@ -87,34 +137,12 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||
return removeFromRouteTable(prefix, addr)
|
||||
}
|
||||
|
||||
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
|
||||
if err != nil {
|
||||
log.Errorf("getting routes returned an error: %v", err)
|
||||
return nil, errRouteNotFound
|
||||
}
|
||||
|
||||
if gateway == nil {
|
||||
return preferredSrc, nil
|
||||
}
|
||||
|
||||
return gateway, nil
|
||||
}
|
||||
|
@ -8,17 +8,63 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String())
|
||||
require.NoError(t, err, "getOutgoingInterfaceLinux should not return error")
|
||||
if invert {
|
||||
require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface")
|
||||
} else {
|
||||
require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
prefixGateway, err := getExistingRIBRouteGateway(prefix)
|
||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
|
||||
func getOutgoingInterfaceLinux(destination string) (string, error) {
|
||||
cmd := exec.Command("ip", "route", "get", destination)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("executing ip route get: %w", err)
|
||||
}
|
||||
|
||||
return parseOutgoingInterface(string(output)), nil
|
||||
}
|
||||
|
||||
func parseOutgoingInterface(routeGetOutput string) string {
|
||||
fields := strings.Fields(routeGetOutput)
|
||||
for i, field := range fields {
|
||||
if field == "dev" && i+1 < len(fields) {
|
||||
return fields[i+1]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestAddRemoveRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -54,23 +100,26 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
|
||||
require.NoError(t, setupRouting())
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
|
||||
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||
|
||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
||||
if testCase.shouldRouteToWireguard {
|
||||
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||
} else {
|
||||
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
||||
}
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name())
|
||||
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
|
||||
|
||||
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
|
||||
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
||||
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
||||
|
||||
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||
@ -189,16 +238,21 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
require.NoError(t, setupRouting())
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
MockAddr := wgInterface.Address().IP.String()
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr)
|
||||
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name())
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr)
|
||||
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name())
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
@ -208,7 +262,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr)
|
||||
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name())
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
@ -217,72 +271,12 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
t.Log("Buffer string: ", buf.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
if !strings.Contains(buf.String(), "because it already exists") {
|
||||
|
||||
// Linux uses a separate routing table, so the route can exist in both tables.
|
||||
// The main routing table takes precedence over the wireguard routing table.
|
||||
if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" {
|
||||
require.False(t, ok, "route should not exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistsInRouteTable(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if p.Addr().Is4() {
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
exists, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("address %s should exist in route table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubRange(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var subRangeAddressPrefixes []netip.Prefix
|
||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range subRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if !isSubRangePrefix {
|
||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if isSubRangePrefix {
|
||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,41 +1,22 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
//go:build !linux || android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||
cmd := exec.Command("route", "add", prefix.String(), addr)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
func setupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if runtime.GOOS == "darwin" {
|
||||
args = append(args, addr)
|
||||
}
|
||||
cmd := exec.Command("route", args...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf(string(out))
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
80
client/internal/routemanager/systemops_nonlinux_test.go
Normal file
80
client/internal/routemanager/systemops_nonlinux_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsSubRange(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var subRangeAddressPrefixes []netip.Prefix
|
||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range subRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if !isSubRangePrefix {
|
||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if isSubRangePrefix {
|
||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistsInRouteTable(t *testing.T) {
|
||||
require.NoError(t, setupRouting())
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if p.Addr().Is4() {
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
exists, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("address %s should exist in route table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,12 +1,13 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
)
|
||||
|
||||
@ -21,17 +22,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
|
||||
err := wmi.Query(query, &routes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
var prefixList []netip.Prefix
|
||||
for _, route := range routes {
|
||||
addr, err := netip.ParseAddr(route.Destination)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to parse route destination %s: %v", route.Destination, err)
|
||||
continue
|
||||
}
|
||||
maskSlice := net.ParseIP(route.Mask).To4()
|
||||
if maskSlice == nil {
|
||||
log.Warnf("Unable to parse route mask %s", route.Mask)
|
||||
continue
|
||||
}
|
||||
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
|
||||
@ -44,3 +47,11 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
}
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericAddToRouteTableIfNoExists(prefix, addr, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error {
|
||||
return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf)
|
||||
}
|
||||
|
24
client/internal/stdnet/dialer.go
Normal file
24
client/internal/stdnet/dialer.go
Normal file
@ -0,0 +1,24 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// Dial connects to the address on the named network.
|
||||
func (n *Net) Dial(network, address string) (net.Conn, error) {
|
||||
return nbnet.NewDialer().Dial(network, address)
|
||||
}
|
||||
|
||||
// DialUDP connects to the address on the named UDP network.
|
||||
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
return nbnet.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
// DialTCP connects to the address on the named TCP network.
|
||||
func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
|
||||
return nbnet.DialTCP(network, laddr, raddr)
|
||||
}
|
20
client/internal/stdnet/listener.go
Normal file
20
client/internal/stdnet/listener.go
Normal file
@ -0,0 +1,20 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ListenPacket listens for incoming packets on the given network and address.
|
||||
func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) {
|
||||
return nbnet.NewListener().ListenPacket(context.Background(), network, address)
|
||||
}
|
||||
|
||||
// ListenUDP acts like ListenPacket for UDP networks.
|
||||
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
return nbnet.ListenUDP(network, locAddr)
|
||||
}
|
@ -1,8 +1,10 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -23,7 +25,7 @@ func (pl portLookup) searchFreePort() (int, error) {
|
||||
}
|
||||
|
||||
func (pl portLookup) tryToBind(port int) error {
|
||||
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
@ -66,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
|
||||
p.conn, err = net.ListenUDP("udp", &addr)
|
||||
p.conn, err = nbnet.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
cErr := p.Free()
|
||||
if cErr != nil {
|
||||
@ -208,20 +209,41 @@ generatePort:
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
return net.FilePacketConn(os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)))
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
@ -33,7 +35,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
var err error
|
||||
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return nil, err
|
||||
|
4
go.mod
4
go.mod
@ -47,8 +47,9 @@ require (
|
||||
github.com/google/go-cmp v0.5.9
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||
github.com/hashicorp/go-multierror v1.1.0
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/libp2p/go-netroute v0.2.0
|
||||
@ -123,7 +124,6 @@ require (
|
||||
github.com/google/s2a-go v0.1.4 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
|
||||
github.com/gopacket/gopacket v1.1.1 // indirect
|
||||
github.com/hashicorp/errwrap v1.0.0 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
|
4
go.sum
4
go.sum
@ -291,8 +291,8 @@ github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f2
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI=
|
||||
github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
|
||||
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
|
||||
|
@ -23,6 +23,24 @@ func parseWGAddress(address string) (WGAddress, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Masked returns the WGAddress with the IP address part masked according to its network mask.
|
||||
func (addr WGAddress) Masked() WGAddress {
|
||||
ip := addr.IP.To4()
|
||||
if ip == nil {
|
||||
ip = addr.IP.To16()
|
||||
}
|
||||
|
||||
maskedIP := make(net.IP, len(ip))
|
||||
for i := range ip {
|
||||
maskedIP[i] = ip[i] & addr.Network.Mask[i]
|
||||
}
|
||||
|
||||
return WGAddress{
|
||||
IP: maskedIP,
|
||||
Network: addr.Network,
|
||||
}
|
||||
}
|
||||
|
||||
func (addr WGAddress) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type wgKernelConfigurer struct {
|
||||
@ -29,7 +31,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fwmark := 0
|
||||
fwmark := nbnet.NetbirdFwmark
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ReplacePeers: true,
|
||||
|
@ -13,6 +13,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type wgUSPConfigurer struct {
|
||||
@ -37,7 +39,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fwmark := 0
|
||||
fwmark := getFwmark()
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ReplacePeers: true,
|
||||
@ -345,3 +347,10 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if runtime.GOOS == "linux" {
|
||||
return nbnet.NetbirdFwmark
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||
)
|
||||
|
||||
const ConnectTimeout = 10 * time.Second
|
||||
@ -57,6 +58,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
||||
mgmCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
nbgrpc.WithCustomDialer(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
|
@ -21,6 +21,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// ErrSharedSockStopped indicates that shared socket has been stopped
|
||||
@ -55,8 +57,7 @@ var writeSerializerOptions = gopacket.SerializeOptions{
|
||||
}
|
||||
|
||||
// Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines
|
||||
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
|
||||
var err error
|
||||
func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rawSock := &SharedSocket{
|
||||
ctx: ctx,
|
||||
@ -65,37 +66,51 @@ func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
|
||||
packetDemux: make(chan rcvdPacket),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if closeErr := rawSock.Close(); closeErr != nil {
|
||||
log.Errorf("Failed to close raw socket: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
rawSock.router, err = netroute.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create raw socket router: %v", err)
|
||||
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
|
||||
}
|
||||
|
||||
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipv4 raw socket: %v", err)
|
||||
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
|
||||
}
|
||||
|
||||
rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create ipv6 raw socket: %v", err)
|
||||
if err = nbnet.SetSocketMark(rawSock.conn4); err != nil {
|
||||
return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err)
|
||||
}
|
||||
|
||||
var sockErr error
|
||||
rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
|
||||
if sockErr != nil {
|
||||
log.Errorf("Failed to create ipv6 raw socket: %v", err)
|
||||
} else {
|
||||
if err = nbnet.SetSocketMark(rawSock.conn6); err != nil {
|
||||
return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port))
|
||||
if err != nil {
|
||||
_ = rawSock.Close()
|
||||
return nil, fmt.Errorf("getBPFInstructions failed with: %rawSock", err)
|
||||
return nil, fmt.Errorf("getBPFInstructions failed with: %w", err)
|
||||
}
|
||||
|
||||
err = rawSock.conn4.SetBPF(ipv4Instructions)
|
||||
if err != nil {
|
||||
_ = rawSock.Close()
|
||||
return nil, fmt.Errorf("socket4.SetBPF failed with: %rawSock", err)
|
||||
return nil, fmt.Errorf("socket4.SetBPF failed with: %w", err)
|
||||
}
|
||||
if rawSock.conn6 != nil {
|
||||
err = rawSock.conn6.SetBPF(ipv6Instructions)
|
||||
if err != nil {
|
||||
_ = rawSock.Close()
|
||||
return nil, fmt.Errorf("socket6.SetBPF failed with: %rawSock", err)
|
||||
return nil, fmt.Errorf("socket6.SetBPF failed with: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,7 +136,7 @@ func (s *SharedSocket) updateRouter() {
|
||||
case <-ticker.C:
|
||||
router, err := netroute.New()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create and update packet router for stunListener: %s", err)
|
||||
log.Errorf("Failed to create and update packet router for stunListener: %s", err)
|
||||
continue
|
||||
}
|
||||
s.routerMux.Lock()
|
||||
@ -144,7 +159,7 @@ func (s *SharedSocket) LocalAddr() net.Addr {
|
||||
func (s *SharedSocket) SetDeadline(t time.Time) error {
|
||||
err := s.conn4.SetDeadline(t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.conn4.SetDeadline error: %s", err)
|
||||
return fmt.Errorf("s.conn4.SetDeadline error: %w", err)
|
||||
}
|
||||
if s.conn6 == nil {
|
||||
return nil
|
||||
@ -152,7 +167,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error {
|
||||
|
||||
err = s.conn6.SetDeadline(t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.conn6.SetDeadline error: %s", err)
|
||||
return fmt.Errorf("s.conn6.SetDeadline error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -161,7 +176,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error {
|
||||
func (s *SharedSocket) SetReadDeadline(t time.Time) error {
|
||||
err := s.conn4.SetReadDeadline(t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.conn4.SetReadDeadline error: %s", err)
|
||||
return fmt.Errorf("s.conn4.SetReadDeadline error: %w", err)
|
||||
}
|
||||
if s.conn6 == nil {
|
||||
return nil
|
||||
@ -169,7 +184,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error {
|
||||
|
||||
err = s.conn6.SetReadDeadline(t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.conn6.SetReadDeadline error: %s", err)
|
||||
return fmt.Errorf("s.conn6.SetReadDeadline error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -178,7 +193,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error {
|
||||
func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
|
||||
err := s.conn4.SetWriteDeadline(t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.conn4.SetWriteDeadline error: %s", err)
|
||||
return fmt.Errorf("s.conn4.SetWriteDeadline error: %w", err)
|
||||
}
|
||||
if s.conn6 == nil {
|
||||
return nil
|
||||
@ -186,7 +201,7 @@ func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
|
||||
|
||||
err = s.conn6.SetWriteDeadline(t)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.conn6.SetWriteDeadline error: %s", err)
|
||||
return fmt.Errorf("s.conn6.SetWriteDeadline error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -282,7 +297,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
|
||||
_, _, src, err := s.router.Route(rUDPAddr.IP)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("got an error while checking route, err: %s", err)
|
||||
return 0, fmt.Errorf("got an error while checking route, err: %w", err)
|
||||
}
|
||||
|
||||
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
|
||||
@ -292,7 +307,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
}
|
||||
|
||||
if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil {
|
||||
return -1, fmt.Errorf("failed serialize rcvdPacket: %s", err)
|
||||
return -1, fmt.Errorf("failed serialize rcvdPacket: %w", err)
|
||||
}
|
||||
|
||||
bufser := buffer.Bytes()
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||
)
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||
@ -76,6 +77,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
||||
sigCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
nbgrpc.WithCustomDialer(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
|
9
util/grpc/dialer_generic.go
Normal file
9
util/grpc/dialer_generic.go
Normal file
@ -0,0 +1,9 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package grpc
|
||||
|
||||
import "google.golang.org/grpc"
|
||||
|
||||
func WithCustomDialer() grpc.DialOption {
|
||||
return grpc.EmptyDialOption{}
|
||||
}
|
18
util/grpc/dialer_linux.go
Normal file
18
util/grpc/dialer_linux.go
Normal file
@ -0,0 +1,18 @@
|
||||
//go:build !android
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer() grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
})
|
||||
}
|
19
util/net/dialer_generic.go
Normal file
19
util/net/dialer_generic.go
Normal file
@ -0,0 +1,19 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func NewDialer() *net.Dialer {
|
||||
return &net.Dialer{}
|
||||
}
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
60
util/net/dialer_linux.go
Normal file
60
util/net/dialer_linux.go
Normal file
@ -0,0 +1,60 @@
|
||||
//go:build !android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewDialer() *net.Dialer {
|
||||
return &net.Dialer{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
return SetRawSocketMark(c)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
udpConn, ok := conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got different type")
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got different type")
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
13
util/net/listener_generic.go
Normal file
13
util/net/listener_generic.go
Normal file
@ -0,0 +1,13 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package net
|
||||
|
||||
import "net"
|
||||
|
||||
func NewListener() *net.ListenConfig {
|
||||
return &net.ListenConfig{}
|
||||
}
|
||||
|
||||
func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
return net.ListenUDP(network, locAddr)
|
||||
}
|
30
util/net/listener_linux.go
Normal file
30
util/net/listener_linux.go
Normal file
@ -0,0 +1,30 @@
|
||||
//go:build !android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func NewListener() *net.ListenConfig {
|
||||
return &net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
return SetRawSocketMark(c)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err)
|
||||
}
|
||||
udpConn, ok := pc.(*net.UDPConn)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("packetConn is not a *net.UDPConn")
|
||||
}
|
||||
return udpConn, nil
|
||||
}
|
6
util/net/net.go
Normal file
6
util/net/net.go
Normal file
@ -0,0 +1,6 @@
|
||||
package net
|
||||
|
||||
const (
|
||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||
NetbirdFwmark = 0x1BD00
|
||||
)
|
35
util/net/net_linux.go
Normal file
35
util/net/net_linux.go
Normal file
@ -0,0 +1,35 @@
|
||||
//go:build !android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||
func SetSocketMark(conn syscall.Conn) error {
|
||||
sysconn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get raw conn: %w", err)
|
||||
}
|
||||
|
||||
return SetRawSocketMark(sysconn)
|
||||
}
|
||||
|
||||
func SetRawSocketMark(conn syscall.RawConn) error {
|
||||
var setErr error
|
||||
|
||||
err := conn.Control(func(fd uintptr) {
|
||||
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
}
|
||||
|
||||
if setErr != nil {
|
||||
return fmt.Errorf("set SO_MARK: %w", setErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user