Replaces powershell with the route command and cache route lookups on windows (#1880)

This commit is contained in:
Viktor Liu 2024-04-26 16:37:27 +02:00 committed by GitHub
parent 71c6437bab
commit 54b045d9ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 161 additions and 135 deletions

View File

@ -3,6 +3,7 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"time"
@ -215,7 +216,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil {
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
}
@ -256,7 +257,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// otherwise add the route to the system
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
}
@ -344,3 +345,15 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
}
func (c *clientNetwork) getAsInterface() *net.Interface {
intf, err := net.InterfaceByName(c.wgInterface.Name())
if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
}
}
return intf
}

View File

@ -5,6 +5,7 @@ package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
@ -17,7 +18,7 @@ import (
type ref struct {
count int
nexthop netip.Addr
intf string
intf *net.Interface
}
type RouteManager struct {
@ -30,8 +31,8 @@ type RouteManager struct {
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap

View File

@ -60,17 +60,13 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
return nil
}
var exitIntf string
gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
if intf != nil {
exitIntf = intf.Name
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
}
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
@ -84,7 +80,7 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
@ -153,12 +149,7 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(
prefix netip.Prefix,
vpnIntf *iface.WGIface,
initialNextHop netip.Addr,
initialIntf *net.Interface,
) (netip.Addr, string, error) {
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
@ -168,39 +159,34 @@ func addRouteToNonVPNIntf(
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, "", ErrRouteNotAllowed
return netip.Addr{}, nil, ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr)
if err != nil {
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
var exitIntf string
if intf != nil {
exitIntf = intf.Name
}
exitIntf := intf
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
if initialIntf != nil {
exitIntf = initialIntf.Name
}
exitIntf = initialIntf
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
@ -208,7 +194,7 @@ func addRouteToNonVPNIntf(
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
@ -250,7 +236,7 @@ func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
@ -277,7 +263,7 @@ func addNonExistingRoute(prefix netip.Prefix, intf string) error {
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
@ -343,7 +329,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, string, error) {
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {

View File

@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(netip.Prefix, string) error {
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, string) error {
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@ -27,15 +27,15 @@ func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("add", prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("delete", prefix, nexthop, intf)
}
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
@ -46,15 +46,15 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin
// Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 {
intf = "lo0"
intf = &net.Interface{Name: "lo0"}
}
}
args := []string{"-n", action, inet, network}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != "" {
args = append(args, "-interface", intf)
} else if intf != nil {
args = append(args, "-interface", intf.Name)
}
if err := retryRouteCmd(args); err != nil {

View File

@ -33,7 +33,7 @@ func init() {
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := "lo0"
intf := &net.Interface{Name: "lo0"}
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {

View File

@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(netip.Prefix, string) error {
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, string) error {
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@ -46,9 +46,6 @@ var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
@ -62,6 +59,20 @@ type ruleParams struct {
description string
}
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
}
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams {
return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
@ -82,7 +93,7 @@ func getSetupRules() []ruleParams {
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
if isLegacy() {
log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
@ -111,7 +122,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
isLegacy = true
setIsLegacy(true)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
@ -125,7 +136,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error {
if isLegacy {
if isLegacy() {
return cleanupRoutingWithRouteManager(routeManager)
}
@ -154,16 +165,16 @@ func cleanupRouting() error {
return result.ErrorOrNil()
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func addVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericAddVPNRoute(prefix, intf)
}
@ -185,8 +196,8 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return nil
}
func removeVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericRemoveVPNRoute(prefix, intf)
}
@ -244,7 +255,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
@ -316,7 +327,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@ -470,21 +481,23 @@ func removeRule(params ruleParams) error {
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
route.Gw = addr.AsSlice()
if intf == "" {
intf = addr.Zone()
}
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if intf != nil {
route.LinkIndex = intf.Index
}
if intf != "" {
link, err := netlink.LinkByName(intf)
if addr.IsValid() {
route.Gw = addr.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index
if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(addr.Zone())
if err != nil {
return fmt.Errorf("set interface %s: %w", intf, err)
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
}
}
return nil
}

View File

@ -3,6 +3,7 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
@ -14,10 +15,10 @@ func enableIPForwarding() error {
return nil
}
func addVPNRoute(prefix netip.Prefix, intf string) error {
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf string) error {
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@ -50,6 +50,8 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@ -67,7 +69,11 @@ func TestAddRemoveRoutes(t *testing.T) {
assert.NoError(t, cleanupRouting())
})
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
@ -78,7 +84,7 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
@ -182,12 +188,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
@ -200,14 +210,18 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
err := addVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@ -217,7 +231,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
@ -345,43 +359,47 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, cleanupRouting())
})
index, err := net.InterfaceByName(wgIface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}

View File

@ -6,8 +6,12 @@ import (
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
@ -21,6 +25,10 @@ type Win32_IP4RouteTable struct {
Mask string
}
var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@ -32,15 +40,23 @@ func cleanupRouting() error {
}
func getRoutesFromTable() ([]netip.Prefix, error) {
var routes []Win32_IP4RouteTable
mux.Lock()
defer mux.Unlock()
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
return prefixList, nil
}
var routes []Win32_IP4RouteTable
err := wmi.Query(query, &routes)
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
var prefixList []netip.Prefix
prefixList = nil
for _, route := range routes {
addr, err := netip.ParseAddr(route.Destination)
if err != nil {
@ -60,54 +76,29 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
prefixList = append(prefixList, routePrefix)
}
}
lastUpdate = time.Now()
return prefixList, nil
}
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
destinationPrefix := prefix.String()
psCmd := "New-NetRoute"
addressFamily := "IPv4"
if prefix.Addr().Is6() {
addressFamily = "IPv6"
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
psCmd, addressFamily, destinationPrefix,
)
if intfIdx != "" {
script = fmt.Sprintf(
`%s -InterfaceIndex %s`, script, intfIdx,
)
} else {
script = fmt.Sprintf(
`%s -InterfaceAlias "%s"`, script, intf,
)
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
args := []string{"add", prefix.String()}
if nexthop.IsValid() {
script = fmt.Sprintf(
`%s -NextHop "%s"`, script, nexthop,
)
args = append(args, nexthop.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
}
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
log.Tracef("PowerShell %s: %s", script, string(out))
if err != nil {
return fmt.Errorf("PowerShell add route: %w", err)
if intf != nil {
args = append(args, "if", strconv.Itoa(intf.Index))
}
return nil
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
@ -116,21 +107,20 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
return nil
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
var intfIdx string
if nexthop.Zone() != "" {
intfIdx = nexthop.Zone()
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("")
}
// Powershell doesn't support adding routes without an interface but allows to add interface by name
if intf != "" || intfIdx != "" {
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
@ -145,3 +135,7 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) err
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}