mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-10 18:58:27 +02:00
Linux legacy routing (#1774)
* Add Linux legacy routing if ip rule functionality is not available * Ignore exclusion route errors if host has no route * Exclude iOS from route manager * Also retrieve IPv6 routes * Ignore loopback addresses not being in the main table * Ignore "not supported" errors on cleanup * Fix regression in ListenUDP not using fwmarks
This commit is contained in:
parent
7938295190
commit
bb0d5c5baf
@ -1,8 +1,9 @@
|
|||||||
//go:build !android
|
//go:build !android && !ios
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
@ -53,6 +54,9 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref
|
|||||||
if ref.count == 0 {
|
if ref.count == 0 {
|
||||||
log.Debugf("Adding route for prefix %s", prefix)
|
log.Debugf("Adding route for prefix %s", prefix)
|
||||||
nexthop, intf, err := rm.addRoute(prefix)
|
nexthop, intf, err := rm.addRoute(prefix)
|
||||||
|
if errors.Is(err, errRouteNotFound) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
||||||
}
|
}
|
||||||
|
410
client/internal/routemanager/systemops.go
Normal file
410
client/internal/routemanager/systemops.go
Normal file
@ -0,0 +1,410 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/libp2p/go-netroute"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||||
|
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||||
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
|
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||||
|
|
||||||
|
var errRouteNotFound = fmt.Errorf("route not found")
|
||||||
|
|
||||||
|
// TODO: fix: for default our wg address now appears as the default gw
|
||||||
|
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||||
|
addr := netip.IPv4Unspecified()
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
addr = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultGateway, _, err := getNextHop(addr)
|
||||||
|
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||||
|
return fmt.Errorf("get existing route gateway: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !prefix.Contains(defaultGateway) {
|
||||||
|
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
||||||
|
if defaultGateway.Is6() {
|
||||||
|
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := existsInRouteTable(gatewayPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitIntf string
|
||||||
|
gatewayHop, intf, err := getNextHop(defaultGateway)
|
||||||
|
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||||
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||||
|
}
|
||||||
|
if intf != nil {
|
||||||
|
exitIntf = intf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||||
|
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||||
|
r, err := netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||||
|
}
|
||||||
|
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get route for %s: %v", ip, err)
|
||||||
|
return netip.Addr{}, nil, errRouteNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||||
|
if gateway == nil {
|
||||||
|
if preferredSrc == nil {
|
||||||
|
return netip.Addr{}, nil, errRouteNotFound
|
||||||
|
}
|
||||||
|
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
||||||
|
|
||||||
|
addr, ok := netip.AddrFromSlice(preferredSrc)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc)
|
||||||
|
}
|
||||||
|
return addr.Unmap(), intf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, ok := netip.AddrFromSlice(gateway)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr.Unmap(), intf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get routes from table: %w", err)
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute == prefix {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get routes from table: %w", err)
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRouteToNonVPNIntf returns the next hop and interface for the given prefix.
|
||||||
|
// If the next hop or interface is pointing to the VPN interface, it will return an error
|
||||||
|
func addRouteToNonVPNIntf(
|
||||||
|
prefix netip.Prefix,
|
||||||
|
vpnIntf *iface.WGIface,
|
||||||
|
initialNextHop netip.Addr,
|
||||||
|
initialIntf *net.Interface,
|
||||||
|
) (netip.Addr, string, error) {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
switch {
|
||||||
|
case addr.IsLoopback():
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix)
|
||||||
|
case addr.IsLinkLocalUnicast():
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix)
|
||||||
|
case addr.IsLinkLocalMulticast():
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix)
|
||||||
|
case addr.IsInterfaceLocalMulticast():
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix)
|
||||||
|
case addr.IsUnspecified():
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix)
|
||||||
|
case addr.IsMulticast():
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||||
|
nexthop, intf, err := getNextHop(addr)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||||
|
exitNextHop := nexthop
|
||||||
|
var exitIntf string
|
||||||
|
if intf != nil {
|
||||||
|
exitIntf = intf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||||
|
}
|
||||||
|
|
||||||
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||||
|
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
|
||||||
|
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||||
|
exitNextHop = initialNextHop
|
||||||
|
if initialIntf != nil {
|
||||||
|
exitIntf = initialIntf.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||||
|
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||||
|
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return exitNextHop, exitIntf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||||
|
// in two /1 prefixes to avoid replacing the existing default route
|
||||||
|
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if prefix == defaultv4 {
|
||||||
|
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove once IPv6 is supported on the interface
|
||||||
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
} else if prefix == defaultv6 {
|
||||||
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||||
|
}
|
||||||
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||||
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return addNonExistingRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||||
|
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
ok, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("exists in route table: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = isSubRange(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sub range: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
err := addRouteForCurrentDefaultGateway(prefix)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addToRouteTable(prefix, netip.Addr{}, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||||
|
// it will remove the split /1 prefixes
|
||||||
|
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if prefix == defaultv4 {
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove once IPv6 is supported on the interface
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
} else if prefix == defaultv6 {
|
||||||
|
var result *multierror.Error
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
||||||
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("parse IP address: %s", ip)
|
||||||
|
}
|
||||||
|
addr = addr.Unmap()
|
||||||
|
|
||||||
|
var prefixLength int
|
||||||
|
switch {
|
||||||
|
case addr.Is4():
|
||||||
|
prefixLength = 32
|
||||||
|
case addr.Is6():
|
||||||
|
prefixLength = 128
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||||
|
return &prefix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
||||||
|
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||||
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||||
|
}
|
||||||
|
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
||||||
|
if err != nil && !errors.Is(err, errRouteNotFound) {
|
||||||
|
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
*routeManager = NewRouteManager(
|
||||||
|
func(prefix netip.Prefix) (netip.Addr, string, error) {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||||
|
if addr.Is6() {
|
||||||
|
nexthop, intf = initialNextHopV6, initialIntfV6
|
||||||
|
}
|
||||||
|
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
||||||
|
},
|
||||||
|
removeFromRouteTable,
|
||||||
|
)
|
||||||
|
|
||||||
|
return setupHooks(*routeManager, initAddresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
||||||
|
if routeManager == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Remove hooks selectively
|
||||||
|
nbnet.RemoveDialerHooks()
|
||||||
|
nbnet.RemoveListenerHooks()
|
||||||
|
|
||||||
|
if err := routeManager.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush route manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||||
|
prefix, err := getPrefixFromIP(ip)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
||||||
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
afterHook := func(connID nbnet.ConnectionID) error {
|
||||||
|
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
||||||
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range initAddresses {
|
||||||
|
if err := beforeHook("init", ip); err != nil {
|
||||||
|
log.Errorf("Failed to add route reference: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for _, ip := range resolvedIPs {
|
||||||
|
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||||
|
}
|
||||||
|
return result.ErrorOrNil()
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||||
|
return afterHook(connID)
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||||
|
return beforeHook(connID, ip.IP)
|
||||||
|
})
|
||||||
|
|
||||||
|
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||||
|
return afterHook(connID)
|
||||||
|
})
|
||||||
|
|
||||||
|
return beforeHook, afterHook, nil
|
||||||
|
}
|
@ -35,6 +35,9 @@ const (
|
|||||||
|
|
||||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||||
|
|
||||||
|
var routeManager = &RouteManager{}
|
||||||
|
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
|
||||||
|
|
||||||
type ruleParams struct {
|
type ruleParams struct {
|
||||||
fwmark int
|
fwmark int
|
||||||
tableID int
|
tableID int
|
||||||
@ -66,7 +69,12 @@ func getSetupRules() []ruleParams {
|
|||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
//
|
//
|
||||||
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
||||||
func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||||
|
if isLegacy {
|
||||||
|
log.Infof("Using legacy routing setup")
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
|
|
||||||
if err = addRoutingTableName(); err != nil {
|
if err = addRoutingTableName(); err != nil {
|
||||||
log.Errorf("Error adding routing table name: %v", err)
|
log.Errorf("Error adding routing table name: %v", err)
|
||||||
}
|
}
|
||||||
@ -82,6 +90,11 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee
|
|||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := addRule(rule); err != nil {
|
if err := addRule(rule); err != nil {
|
||||||
|
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
|
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||||
|
isLegacy = true
|
||||||
|
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||||
|
}
|
||||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -93,6 +106,10 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee
|
|||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
func cleanupRouting() error {
|
func cleanupRouting() error {
|
||||||
|
if isLegacy {
|
||||||
|
return cleanupRoutingWithRouteManager(routeManager)
|
||||||
|
}
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
||||||
@ -104,7 +121,7 @@ func cleanupRouting() error {
|
|||||||
|
|
||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := removeAllRules(rule); err != nil {
|
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -112,49 +129,104 @@ func cleanupRouting() error {
|
|||||||
return result.ErrorOrNil()
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
||||||
|
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||||
|
}
|
||||||
|
|
||||||
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 2
|
if isLegacy {
|
||||||
|
return genericAddVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
// TODO remove this once we have ipv6 support
|
||||||
if prefix == defaultv4 {
|
if prefix == defaultv4 {
|
||||||
if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("add blackhole: %w", err)
|
return fmt.Errorf("add blackhole: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("add route: %w", err)
|
return fmt.Errorf("add route: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
|
if isLegacy {
|
||||||
|
return genericRemoveVPNRoute(prefix, intf)
|
||||||
|
}
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
// TODO remove this once we have ipv6 support
|
||||||
if prefix == defaultv4 {
|
if prefix == defaultv4 {
|
||||||
if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("remove unreachable route: %w", err)
|
return fmt.Errorf("remove unreachable route: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||||
return fmt.Errorf("remove route: %w", err)
|
return fmt.Errorf("remove route: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
|
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get v4 routes: %w", err)
|
||||||
|
}
|
||||||
|
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get v6 routes: %w", err)
|
||||||
|
|
||||||
|
}
|
||||||
|
return append(v4Routes, v6Routes...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRoutes fetches routes from a specific routing table identified by tableID.
|
||||||
|
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||||
|
var prefixList []netip.Prefix
|
||||||
|
|
||||||
|
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
if route.Dst != nil {
|
||||||
|
addr, ok := netip.AddrFromSlice(route.Dst.IP)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := route.Dst.Mask.Size()
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(addr, ones)
|
||||||
|
if prefix.IsValid() {
|
||||||
|
prefixList = append(prefixList, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefixList, nil
|
||||||
|
}
|
||||||
|
|
||||||
// addRoute adds a route to a specific routing table identified by tableID.
|
// addRoute adds a route to a specific routing table identified by tableID.
|
||||||
func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||||
route := &netlink.Route{
|
route := &netlink.Route{
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
Table: tableID,
|
Table: tableID,
|
||||||
Family: family,
|
Family: getAddressFamily(prefix),
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix != nil {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
}
|
}
|
||||||
route.Dst = ipNet
|
route.Dst = ipNet
|
||||||
}
|
|
||||||
|
|
||||||
if err := addNextHop(addr, intf, route); err != nil {
|
if err := addNextHop(addr, intf, route); err != nil {
|
||||||
return fmt.Errorf("add gateway and device: %w", err)
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
@ -170,7 +242,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err
|
|||||||
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
||||||
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
||||||
// tableID specifies the routing table to which the unreachable route will be added.
|
// tableID specifies the routing table to which the unreachable route will be added.
|
||||||
func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
@ -179,7 +251,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
|||||||
route := &netlink.Route{
|
route := &netlink.Route{
|
||||||
Type: syscall.RTN_UNREACHABLE,
|
Type: syscall.RTN_UNREACHABLE,
|
||||||
Table: tableID,
|
Table: tableID,
|
||||||
Family: ipFamily,
|
Family: getAddressFamily(prefix),
|
||||||
Dst: ipNet,
|
Dst: ipNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,7 +262,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
@ -199,7 +271,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
|||||||
route := &netlink.Route{
|
route := &netlink.Route{
|
||||||
Type: syscall.RTN_UNREACHABLE,
|
Type: syscall.RTN_UNREACHABLE,
|
||||||
Table: tableID,
|
Table: tableID,
|
||||||
Family: ipFamily,
|
Family: getAddressFamily(prefix),
|
||||||
Dst: ipNet,
|
Dst: ipNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -212,7 +284,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||||
func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
|
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||||
@ -221,7 +293,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int)
|
|||||||
route := &netlink.Route{
|
route := &netlink.Route{
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
Table: tableID,
|
Table: tableID,
|
||||||
Family: family,
|
Family: getAddressFamily(prefix),
|
||||||
Dst: ipNet,
|
Dst: ipNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -392,23 +464,25 @@ func removeAllRules(params ruleParams) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addNextHop adds the gateway and device to the route.
|
// addNextHop adds the gateway and device to the route.
|
||||||
func addNextHop(addr *string, intf *string, route *netlink.Route) error {
|
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
||||||
if addr != nil {
|
if addr.IsValid() {
|
||||||
ip := net.ParseIP(*addr)
|
route.Gw = addr.AsSlice()
|
||||||
if ip == nil {
|
|
||||||
return fmt.Errorf("parsing address %s failed", *addr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
route.Gw = ip
|
if intf != "" {
|
||||||
}
|
link, err := netlink.LinkByName(intf)
|
||||||
|
|
||||||
if intf != nil {
|
|
||||||
link, err := netlink.LinkByName(*intf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set interface %s: %w", *intf, err)
|
return fmt.Errorf("set interface %s: %w", intf, err)
|
||||||
}
|
}
|
||||||
route.LinkIndex = link.Attrs().Index
|
route.LinkIndex = link.Attrs().Index
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getAddressFamily(prefix netip.Prefix) int {
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
return netlink.FAMILY_V4
|
||||||
|
}
|
||||||
|
return netlink.FAMILY_V6
|
||||||
|
}
|
||||||
|
@ -21,8 +21,6 @@ var expectedLoopbackInt = "lo"
|
|||||||
var expectedExternalInt = "dummyext0"
|
var expectedExternalInt = "dummyext0"
|
||||||
var expectedInternalInt = "dummyint0"
|
var expectedInternalInt = "dummyint0"
|
||||||
|
|
||||||
var errRouteNotFound = fmt.Errorf("route not found")
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
testCases = append(testCases, []testCase{
|
testCases = append(testCases, []testCase{
|
||||||
{
|
{
|
||||||
|
@ -3,414 +3,21 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"github.com/libp2p/go-netroute"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
|
||||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
|
||||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
|
||||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
|
||||||
|
|
||||||
var errRouteNotFound = fmt.Errorf("route not found")
|
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func enableIPForwarding() error {
|
||||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: fix: for default our wg address now appears as the default gw
|
|
||||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|
||||||
addr := netip.IPv4Unspecified()
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
addr = netip.IPv6Unspecified()
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultGateway, _, err := getNextHop(addr)
|
|
||||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
|
||||||
return fmt.Errorf("get existing route gateway: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !prefix.Contains(defaultGateway) {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
|
||||||
if defaultGateway.Is6() {
|
|
||||||
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := existsInRouteTable(gatewayPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var exitIntf string
|
|
||||||
gatewayHop, intf, err := getNextHop(defaultGateway)
|
|
||||||
if err != nil && !errors.Is(err, errRouteNotFound) {
|
|
||||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
|
||||||
}
|
|
||||||
if intf != nil {
|
|
||||||
exitIntf = intf.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
|
||||||
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
|
||||||
r, err := netroute.New()
|
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
|
||||||
}
|
|
||||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Getting routes returned an error: %v", err)
|
|
||||||
return netip.Addr{}, nil, errRouteNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
|
||||||
if gateway == nil {
|
|
||||||
if preferredSrc == nil {
|
|
||||||
return netip.Addr{}, nil, errRouteNotFound
|
|
||||||
}
|
|
||||||
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(preferredSrc)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc)
|
|
||||||
}
|
|
||||||
return addr.Unmap(), intf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(gateway)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
return addr.Unmap(), intf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute == prefix {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRouteToNonVPNIntf returns the next hop and interface for the given prefix.
|
|
||||||
// If the next hop or interface is pointing to the VPN interface, it will return an error
|
|
||||||
func addRouteToNonVPNIntf(
|
|
||||||
prefix netip.Prefix,
|
|
||||||
vpnIntf *iface.WGIface,
|
|
||||||
initialNextHop netip.Addr,
|
|
||||||
initialIntf *net.Interface,
|
|
||||||
) (netip.Addr, string, error) {
|
|
||||||
addr := prefix.Addr()
|
|
||||||
switch {
|
|
||||||
case addr.IsLoopback():
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix)
|
|
||||||
case addr.IsLinkLocalUnicast():
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix)
|
|
||||||
case addr.IsLinkLocalMulticast():
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix)
|
|
||||||
case addr.IsInterfaceLocalMulticast():
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix)
|
|
||||||
case addr.IsUnspecified():
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix)
|
|
||||||
case addr.IsMulticast():
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
|
||||||
nexthop, intf, err := getNextHop(addr)
|
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
|
||||||
exitNextHop := nexthop
|
|
||||||
var exitIntf string
|
|
||||||
if intf != nil {
|
|
||||||
exitIntf = intf.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
|
|
||||||
}
|
|
||||||
|
|
||||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
|
||||||
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
|
|
||||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
|
||||||
exitNextHop = initialNextHop
|
|
||||||
if initialIntf != nil {
|
|
||||||
exitIntf = initialIntf.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
|
||||||
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return exitNextHop, exitIntf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
|
||||||
// in two /1 prefixes to avoid replacing the existing default route
|
|
||||||
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
if prefix == defaultv4 {
|
return genericAddVPNRoute(prefix, intf)
|
||||||
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
|
||||||
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove once IPv6 is supported on the interface
|
|
||||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
||||||
}
|
|
||||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
} else if prefix == defaultv6 {
|
|
||||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
||||||
}
|
|
||||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return addNonExistingRoute(prefix, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
|
||||||
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
|
|
||||||
ok, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("exists in route table: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sub range: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
err := addRouteForCurrentDefaultGateway(prefix)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return addToRouteTable(prefix, netip.Addr{}, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
|
||||||
// it will remove the split /1 prefixes
|
|
||||||
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
||||||
if prefix == defaultv4 {
|
return genericRemoveVPNRoute(prefix, intf)
|
||||||
var result *multierror.Error
|
|
||||||
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: remove once IPv6 is supported on the interface
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
} else if prefix == defaultv6 {
|
|
||||||
var result *multierror.Error
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
|
||||||
addr, ok := netip.AddrFromSlice(ip)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("parse IP address: %s", ip)
|
|
||||||
}
|
|
||||||
addr = addr.Unmap()
|
|
||||||
|
|
||||||
var prefixLength int
|
|
||||||
switch {
|
|
||||||
case addr.Is4():
|
|
||||||
prefixLength = 32
|
|
||||||
case addr.Is6():
|
|
||||||
prefixLength = 128
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
|
||||||
return &prefix, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
|
||||||
}
|
|
||||||
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
*routeManager = NewRouteManager(
|
|
||||||
func(prefix netip.Prefix) (netip.Addr, string, error) {
|
|
||||||
addr := prefix.Addr()
|
|
||||||
nexthop, intf := initialNextHopV4, initialIntfV4
|
|
||||||
if addr.Is6() {
|
|
||||||
nexthop, intf = initialNextHopV6, initialIntfV6
|
|
||||||
}
|
|
||||||
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
|
||||||
},
|
|
||||||
removeFromRouteTable,
|
|
||||||
)
|
|
||||||
|
|
||||||
return setupHooks(*routeManager, initAddresses)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
|
||||||
if routeManager == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Remove hooks selectively
|
|
||||||
nbnet.RemoveDialerHooks()
|
|
||||||
nbnet.RemoveListenerHooks()
|
|
||||||
|
|
||||||
if err := routeManager.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush route manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
|
||||||
prefix, err := getPrefixFromIP(ip)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
|
||||||
return fmt.Errorf("adding route reference: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
afterHook := func(connID nbnet.ConnectionID) error {
|
|
||||||
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ip := range initAddresses {
|
|
||||||
if err := beforeHook("init", ip); err != nil {
|
|
||||||
log.Errorf("Failed to add route reference: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *multierror.Error
|
|
||||||
for _, ip := range resolvedIPs {
|
|
||||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
|
||||||
}
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
|
||||||
return beforeHook(connID, ip.IP)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
return beforeHook, afterHook, nil
|
|
||||||
}
|
}
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
//go:build !linux && !ios
|
//go:build !android && !ios
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -20,16 +22,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
type dialer interface {
|
||||||
t.Helper()
|
Dial(network, address string) (net.Conn, error)
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
prefixGateway, _, err := getNextHop(prefix.Addr())
|
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
|
||||||
if invert {
|
|
||||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddRemoveRoutes(t *testing.T) {
|
func TestAddRemoveRoutes(t *testing.T) {
|
||||||
@ -72,8 +67,8 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
assert.NoError(t, cleanupRouting())
|
assert.NoError(t, cleanupRouting())
|
||||||
})
|
})
|
||||||
|
|
||||||
err = addVPNRoute(testCase.prefix, wgInterface.Name())
|
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||||
|
|
||||||
if testCase.shouldRouteToWireguard {
|
if testCase.shouldRouteToWireguard {
|
||||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||||
@ -83,8 +78,8 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
exists, err := existsInRouteTable(testCase.prefix)
|
exists, err := existsInRouteTable(testCase.prefix)
|
||||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||||
if exists && testCase.shouldRouteToWireguard {
|
if exists && testCase.shouldRouteToWireguard {
|
||||||
err = removeVPNRoute(testCase.prefix, wgInterface.Name())
|
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "removeVPNRoute should not return err")
|
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||||
|
|
||||||
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
require.NoError(t, err, "getNextHop should not return err")
|
||||||
@ -144,7 +139,7 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||||
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
t.Log("defaultGateway: ", defaultGateway)
|
t.Log("defaultGateway: ", defaultGateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -205,20 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
_, _, err = setupRouting(nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
// Prepare the environment
|
// Prepare the environment
|
||||||
if testCase.preExistingPrefix.IsValid() {
|
if testCase.preExistingPrefix.IsValid() {
|
||||||
err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
|
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the route
|
// Add the route
|
||||||
err = addVPNRoute(testCase.prefix, wgInterface.Name())
|
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err when adding route")
|
require.NoError(t, err, "should not return err when adding route")
|
||||||
|
|
||||||
if testCase.shouldAddRoute {
|
if testCase.shouldAddRoute {
|
||||||
@ -228,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
|||||||
require.True(t, ok, "route should exist")
|
require.True(t, ok, "route should exist")
|
||||||
|
|
||||||
// remove route again if added
|
// remove route again if added
|
||||||
err = removeVPNRoute(testCase.prefix, wgInterface.Name())
|
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,12 +273,6 @@ func TestIsSubRange(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestExistsInRouteTable(t *testing.T) {
|
func TestExistsInRouteTable(t *testing.T) {
|
||||||
_, _, err := setupRouting(nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
addresses, err := net.InterfaceAddrs()
|
addresses, err := net.InterfaceAddrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||||
@ -298,10 +281,19 @@ func TestExistsInRouteTable(t *testing.T) {
|
|||||||
var addressPrefixes []netip.Prefix
|
var addressPrefixes []netip.Prefix
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
p := netip.MustParsePrefix(address.String())
|
p := netip.MustParsePrefix(address.String())
|
||||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
if p.Addr().Is6() {
|
||||||
if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() {
|
continue
|
||||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
|
||||||
}
|
}
|
||||||
|
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||||
|
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||||
|
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, prefix := range addressPrefixes {
|
for _, prefix := range addressPrefixes {
|
||||||
@ -314,3 +306,97 @@ func TestExistsInRouteTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
newNet, err := stdnet.NewNet()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||||
|
require.NoError(t, err, "should create testing WireGuard interface")
|
||||||
|
|
||||||
|
err = wgInterface.Create()
|
||||||
|
require.NoError(t, err, "should create testing WireGuard interface")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
wgInterface.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return wgInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestEnv(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
setupDummyInterfacesAndRoutes(t)
|
||||||
|
|
||||||
|
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, wgIface.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
_, _, err := setupRouting(nil, wgIface)
|
||||||
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, cleanupRouting())
|
||||||
|
})
|
||||||
|
|
||||||
|
// default route exists in main table and vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 10.0.0.0/8 route exists in main table and vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 10.10.0.0/24 more specific route exists in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 127.0.10.0/24 more specific route exists in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
|
||||||
|
// unique route in vpn table
|
||||||
|
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||||
|
require.NoError(t, err, "addVPNRoute should not return err")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
||||||
|
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||||
|
t.Helper()
|
||||||
|
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixGateway, _, err := getNextHop(prefix.Addr())
|
||||||
|
require.NoError(t, err, "getNextHop should not return err")
|
||||||
|
if invert {
|
||||||
|
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||||
|
}
|
||||||
|
}
|
@ -1,101 +0,0 @@
|
|||||||
//go:build !android && !ios
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
type dialer interface {
|
|
||||||
Dial(network, address string) (net.Conn, error)
|
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet(nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
|
||||||
require.NoError(t, err, "should create testing WireGuard interface")
|
|
||||||
|
|
||||||
err = wgInterface.Create()
|
|
||||||
require.NoError(t, err, "should create testing WireGuard interface")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
wgInterface.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
return wgInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupTestEnv(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
setupDummyInterfacesAndRoutes(t)
|
|
||||||
|
|
||||||
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, wgIface.Close())
|
|
||||||
})
|
|
||||||
|
|
||||||
_, _, err := setupRouting(nil, wgIface)
|
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
// default route exists in main table and vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// 10.0.0.0/8 route exists in main table and vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// 10.10.0.0/24 more specific route exists in vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// 127.0.10.0/24 more specific route exists in vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// unique route in vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
}
|
|
@ -35,7 +35,7 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
|||||||
udpConn, ok := conn.(*net.UDPConn)
|
udpConn, ok := conn.(*net.UDPConn)
|
||||||
if !ok {
|
if !ok {
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
log.Errorf("Failed to closeConn connection: %v", err)
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("expected UDP connection, got different type")
|
return nil, fmt.Errorf("expected UDP connection, got different type")
|
||||||
}
|
}
|
||||||
|
@ -145,10 +145,19 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
|
|||||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||||
// which includes support for write and close hooks.
|
// which includes support for write and close hooks.
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
|
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
|
||||||
udpConn, err := net.ListenUDP(network, laddr)
|
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||||
}
|
}
|
||||||
connID := GenerateConnID()
|
|
||||||
return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil
|
packetConn := conn.(*PacketConn)
|
||||||
|
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := packetConn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDPConn, got different type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user