mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
Replaces powershell with the route command and cache route lookups on windows (#1880)
This commit is contained in:
parent
71c6437bab
commit
54b045d9ca
@ -3,6 +3,7 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@ -215,7 +216,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.chosenRoute != nil {
|
||||
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
||||
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||
}
|
||||
|
||||
@ -256,7 +257,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
} else {
|
||||
// otherwise add the route to the system
|
||||
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
||||
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||
}
|
||||
@ -344,3 +345,15 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getAsInterface() *net.Interface {
|
||||
intf, err := net.InterfaceByName(c.wgInterface.Name())
|
||||
if err != nil {
|
||||
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
|
||||
intf = &net.Interface{
|
||||
Name: c.wgInterface.Name(),
|
||||
}
|
||||
}
|
||||
|
||||
return intf
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ package routemanager
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@ -17,7 +18,7 @@ import (
|
||||
type ref struct {
|
||||
count int
|
||||
nexthop netip.Addr
|
||||
intf string
|
||||
intf *net.Interface
|
||||
}
|
||||
|
||||
type RouteManager struct {
|
||||
@ -30,8 +31,8 @@ type RouteManager struct {
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
|
||||
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
|
||||
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
|
||||
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
|
||||
|
||||
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
||||
// TODO: read initial routing table into refCountMap
|
||||
|
@ -60,17 +60,13 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitIntf string
|
||||
gatewayHop, intf, err := getNextHop(defaultGateway)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
if intf != nil {
|
||||
exitIntf = intf.Name
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
|
||||
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
||||
}
|
||||
|
||||
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||
@ -84,7 +80,7 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if preferredSrc == nil {
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
@ -153,12 +149,7 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func addRouteToNonVPNIntf(
|
||||
prefix netip.Prefix,
|
||||
vpnIntf *iface.WGIface,
|
||||
initialNextHop netip.Addr,
|
||||
initialIntf *net.Interface,
|
||||
) (netip.Addr, string, error) {
|
||||
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
@ -168,39 +159,34 @@ func addRouteToNonVPNIntf(
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return netip.Addr{}, "", ErrRouteNotAllowed
|
||||
return netip.Addr{}, nil, ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, intf, err := getNextHop(addr)
|
||||
if err != nil {
|
||||
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
|
||||
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||
exitNextHop := nexthop
|
||||
var exitIntf string
|
||||
if intf != nil {
|
||||
exitIntf = intf.Name
|
||||
}
|
||||
exitIntf := intf
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
|
||||
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||
exitNextHop = initialNextHop
|
||||
if initialIntf != nil {
|
||||
exitIntf = initialIntf.Name
|
||||
}
|
||||
exitIntf = initialIntf
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
|
||||
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, exitIntf, nil
|
||||
@ -208,7 +194,7 @@ func addRouteToNonVPNIntf(
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
return err
|
||||
@ -250,7 +236,7 @@ func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
|
||||
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
@ -277,7 +263,7 @@ func addNonExistingRoute(prefix netip.Prefix, intf string) error {
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
@ -343,7 +329,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n
|
||||
}
|
||||
|
||||
*routeManager = NewRouteManager(
|
||||
func(prefix netip.Prefix) (netip.Addr, string, error) {
|
||||
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||
if addr.Is6() {
|
||||
|
@ -24,10 +24,10 @@ func enableIPForwarding() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, string) error {
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, string) error {
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -27,15 +27,15 @@ func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("add", prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("delete", prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
inet := "-inet"
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
@ -46,15 +46,15 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin
|
||||
// Special case for IPv6 split default route, pointing to the wg interface fails
|
||||
// TODO: Remove once we have IPv6 support on the interface
|
||||
if prefix.Bits() == 1 {
|
||||
intf = "lo0"
|
||||
intf = &net.Interface{Name: "lo0"}
|
||||
}
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else if intf != "" {
|
||||
args = append(args, "-interface", intf)
|
||||
} else if intf != nil {
|
||||
args = append(args, "-interface", intf.Name)
|
||||
}
|
||||
|
||||
if err := retryRouteCmd(args); err != nil {
|
||||
|
@ -33,7 +33,7 @@ func init() {
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := "lo0"
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
|
@ -24,10 +24,10 @@ func enableIPForwarding() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, string) error {
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, string) error {
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -46,9 +46,6 @@ var routeManager = &RouteManager{}
|
||||
// originalSysctl stores the original sysctl values before they are modified
|
||||
var originalSysctl map[string]int
|
||||
|
||||
// determines whether to use the legacy routing setup
|
||||
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
|
||||
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
|
||||
var sysctlFailed bool
|
||||
|
||||
@ -62,6 +59,20 @@ type ruleParams struct {
|
||||
description string
|
||||
}
|
||||
|
||||
// isLegacy determines whether to use the legacy routing setup
|
||||
func isLegacy() bool {
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
}
|
||||
|
||||
// setIsLegacy sets the legacy routing setup
|
||||
func setIsLegacy(b bool) {
|
||||
if b {
|
||||
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
} else {
|
||||
os.Unsetenv("NB_USE_LEGACY_ROUTING")
|
||||
}
|
||||
}
|
||||
|
||||
func getSetupRules() []ruleParams {
|
||||
return []ruleParams{
|
||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||
@ -82,7 +93,7 @@ func getSetupRules() []ruleParams {
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||
if isLegacy {
|
||||
if isLegacy() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
@ -111,7 +122,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
if err := addRule(rule); err != nil {
|
||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||
isLegacy = true
|
||||
setIsLegacy(true)
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
@ -125,7 +136,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func cleanupRouting() error {
|
||||
if isLegacy {
|
||||
if isLegacy() {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
}
|
||||
|
||||
@ -154,16 +165,16 @@ func cleanupRouting() error {
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
if isLegacy {
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
@ -185,8 +196,8 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
if isLegacy {
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
@ -244,7 +255,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
}
|
||||
|
||||
// addRoute adds a route to a specific routing table identified by tableID.
|
||||
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
@ -316,7 +327,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
}
|
||||
|
||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
@ -470,21 +481,23 @@ func removeRule(params ruleParams) error {
|
||||
}
|
||||
|
||||
// addNextHop adds the gateway and device to the route.
|
||||
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
||||
if addr.IsValid() {
|
||||
route.Gw = addr.AsSlice()
|
||||
if intf == "" {
|
||||
intf = addr.Zone()
|
||||
}
|
||||
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
|
||||
if intf != nil {
|
||||
route.LinkIndex = intf.Index
|
||||
}
|
||||
|
||||
if intf != "" {
|
||||
link, err := netlink.LinkByName(intf)
|
||||
if addr.IsValid() {
|
||||
route.Gw = addr.AsSlice()
|
||||
|
||||
// if zone is set, it means the gateway is a link-local address, so we set the link index
|
||||
if addr.Zone() != "" && intf == nil {
|
||||
link, err := netlink.LinkByName(addr.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("set interface %s: %w", intf, err)
|
||||
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
@ -14,10 +15,10 @@ func enableIPForwarding() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
@ -50,6 +50,8 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
@ -67,7 +69,11 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
@ -78,7 +84,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||
|
||||
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
||||
@ -182,12 +188,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
@ -200,14 +210,18 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
|
||||
err := addVPNRoute(testCase.preExistingPrefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
@ -217,7 +231,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
@ -345,43 +359,47 @@ func setupTestEnv(t *testing.T) {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgIface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
|
||||
|
||||
// default route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 10.0.0.0/8 route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
|
||||
// unique route in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
}
|
||||
|
@ -6,8 +6,12 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
@ -21,6 +25,10 @@ type Win32_IP4RouteTable struct {
|
||||
Mask string
|
||||
}
|
||||
|
||||
var prefixList []netip.Prefix
|
||||
var lastUpdate time.Time
|
||||
var mux = sync.Mutex{}
|
||||
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
@ -32,15 +40,23 @@ func cleanupRouting() error {
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
var routes []Win32_IP4RouteTable
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
||||
|
||||
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
|
||||
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
var routes []Win32_IP4RouteTable
|
||||
err := wmi.Query(query, &routes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
var prefixList []netip.Prefix
|
||||
prefixList = nil
|
||||
for _, route := range routes {
|
||||
addr, err := netip.ParseAddr(route.Destination)
|
||||
if err != nil {
|
||||
@ -60,54 +76,29 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
prefixList = append(prefixList, routePrefix)
|
||||
}
|
||||
}
|
||||
|
||||
lastUpdate = time.Now()
|
||||
return prefixList, nil
|
||||
}
|
||||
|
||||
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error {
|
||||
destinationPrefix := prefix.String()
|
||||
psCmd := "New-NetRoute"
|
||||
|
||||
addressFamily := "IPv4"
|
||||
if prefix.Addr().Is6() {
|
||||
addressFamily = "IPv6"
|
||||
}
|
||||
|
||||
script := fmt.Sprintf(
|
||||
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
|
||||
psCmd, addressFamily, destinationPrefix,
|
||||
)
|
||||
|
||||
if intfIdx != "" {
|
||||
script = fmt.Sprintf(
|
||||
`%s -InterfaceIndex %s`, script, intfIdx,
|
||||
)
|
||||
} else {
|
||||
script = fmt.Sprintf(
|
||||
`%s -InterfaceAlias "%s"`, script, intf,
|
||||
)
|
||||
}
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
args := []string{"add", prefix.String()}
|
||||
|
||||
if nexthop.IsValid() {
|
||||
script = fmt.Sprintf(
|
||||
`%s -NextHop "%s"`, script, nexthop,
|
||||
)
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else {
|
||||
addr := "0.0.0.0"
|
||||
if prefix.Addr().Is6() {
|
||||
addr = "::"
|
||||
}
|
||||
args = append(args, addr)
|
||||
}
|
||||
|
||||
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
log.Tracef("PowerShell %s: %s", script, string(out))
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("PowerShell add route: %w", err)
|
||||
if intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(intf.Index))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
||||
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
|
||||
|
||||
out, err := exec.Command("route", args...).CombinedOutput()
|
||||
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route add: %w", err)
|
||||
@ -116,21 +107,20 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||
var intfIdx string
|
||||
if nexthop.Zone() != "" {
|
||||
intfIdx = nexthop.Zone()
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
if nexthop.Zone() != "" && intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
intf = &net.Interface{Index: zone}
|
||||
nexthop.WithZone("")
|
||||
}
|
||||
|
||||
// Powershell doesn't support adding routes without an interface but allows to add interface by name
|
||||
if intf != "" || intfIdx != "" {
|
||||
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
|
||||
}
|
||||
return addRouteCmd(prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IsValid() {
|
||||
nexthop.WithZone("")
|
||||
@ -145,3 +135,7 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCacheDisabled() bool {
|
||||
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user