Linux legacy routing (#1774)

* Add Linux legacy routing if ip rule functionality is not available

* Ignore exclusion route errors if host has no route

* Exclude iOS from route manager

* Also retrieve IPv6 routes

* Ignore loopback addresses not being in the main table

* Ignore "not supported" errors on cleanup

* Fix regression in ListenUDP not using fwmarks
This commit is contained in:
Viktor Liu 2024-04-03 18:04:22 +02:00 committed by GitHub
parent 7938295190
commit bb0d5c5baf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 656 additions and 569 deletions

View File

@ -1,8 +1,9 @@
//go:build !android
//go:build !android && !ios
package routemanager
import (
"errors"
"fmt"
"net/netip"
"sync"
@ -53,6 +54,9 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref
if ref.count == 0 {
log.Debugf("Adding route for prefix %s", prefix)
nexthop, intf, err := rm.addRoute(prefix)
if errors.Is(err, errRouteNotFound) {
return nil
}
if err != nil {
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
}

View File

@ -0,0 +1,410 @@
//go:build !android && !ios
package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var errRouteNotFound = fmt.Errorf("route not found")
// TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
defaultGateway, _, err := getNextHop(addr)
if err != nil && !errors.Is(err, errRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(defaultGateway) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
if defaultGateway.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
var exitIntf string
gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, errRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
if intf != nil {
exitIntf = intf.Name
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
}
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
r, err := netroute.New()
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Warnf("Failed to get route for %s: %v", ip, err)
return netip.Addr{}, nil, errRouteNotFound
}
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, errRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, ok := netip.AddrFromSlice(preferredSrc)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc)
}
return addr.Unmap(), intf, nil
}
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway)
}
return addr.Unmap(), intf, nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// getRouteToNonVPNIntf returns the next hop and interface for the given prefix.
// If the next hop or interface is pointing to the VPN interface, it will return an error
func addRouteToNonVPNIntf(
prefix netip.Prefix,
vpnIntf *iface.WGIface,
initialNextHop netip.Addr,
initialIntf *net.Interface,
) (netip.Addr, string, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback():
return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix)
case addr.IsLinkLocalUnicast():
return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix)
case addr.IsLinkLocalMulticast():
return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix)
case addr.IsInterfaceLocalMulticast():
return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix)
case addr.IsUnspecified():
return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix)
case addr.IsMulticast():
return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr)
if err != nil {
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
var exitIntf string
if intf != nil {
exitIntf = intf.Name
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
if initialIntf != nil {
exitIntf = initialIntf.Name
}
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, errRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, errRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, string, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}

View File

@ -35,6 +35,9 @@ const (
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
type ruleParams struct {
fwmark int
tableID int
@ -66,7 +69,12 @@ func getSetupRules() []ruleParams {
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
@ -82,6 +90,11 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee
rules := getSetupRules()
for _, rule := range rules {
if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
isLegacy = true
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
}
}
@ -93,6 +106,10 @@ func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ pee
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error {
if isLegacy {
return cleanupRoutingWithRouteManager(routeManager)
}
var result *multierror.Error
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
@ -104,7 +121,7 @@ func cleanupRouting() error {
rules := getSetupRules()
for _, rule := range rules {
if err := removeAllRules(rule); err != nil {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
@ -112,49 +129,104 @@ func cleanupRouting() error {
return result.ErrorOrNil()
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
}
func addVPNRoute(prefix netip.Prefix, intf string) error {
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 2
if isLegacy {
return genericAddVPNRoute(prefix, intf)
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err)
}
}
if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err)
}
return nil
}
func removeVPNRoute(prefix netip.Prefix, intf string) error {
if isLegacy {
return genericRemoveVPNRoute(prefix, intf)
}
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err)
}
}
if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err)
}
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
if err != nil {
return nil, fmt.Errorf("get v6 routes: %w", err)
}
return append(v4Routes, v6Routes...), nil
}
// getRoutes fetches routes from a specific routing table identified by tableID.
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
var prefixList []netip.Prefix
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
}
for _, route := range routes {
if route.Dst != nil {
addr, ok := netip.AddrFromSlice(route.Dst.IP)
if !ok {
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
}
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
}
}
}
return prefixList, nil
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: family,
Family: getAddressFamily(prefix),
}
if prefix != nil {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route.Dst = ipNet
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route.Dst = ipNet
if err := addNextHop(addr, intf, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
@ -170,7 +242,7 @@ func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) err
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
// tableID specifies the routing table to which the unreachable route will be added.
func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@ -179,7 +251,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: ipFamily,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
@ -190,7 +262,7 @@ func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
return nil
}
func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@ -199,7 +271,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: ipFamily,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
@ -212,7 +284,7 @@ func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error {
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error {
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@ -221,7 +293,7 @@ func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int)
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: family,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
@ -392,23 +464,25 @@ func removeAllRules(params ruleParams) error {
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr *string, intf *string, route *netlink.Route) error {
if addr != nil {
ip := net.ParseIP(*addr)
if ip == nil {
return fmt.Errorf("parsing address %s failed", *addr)
}
route.Gw = ip
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
route.Gw = addr.AsSlice()
}
if intf != nil {
link, err := netlink.LinkByName(*intf)
if intf != "" {
link, err := netlink.LinkByName(intf)
if err != nil {
return fmt.Errorf("set interface %s: %w", *intf, err)
return fmt.Errorf("set interface %s: %w", intf, err)
}
route.LinkIndex = link.Attrs().Index
}
return nil
}
func getAddressFamily(prefix netip.Prefix) int {
if prefix.Addr().Is4() {
return netlink.FAMILY_V4
}
return netlink.FAMILY_V6
}

View File

@ -21,8 +21,6 @@ var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
var errRouteNotFound = fmt.Errorf("route not found")
func init() {
testCases = append(testCases, []testCase{
{

View File

@ -3,414 +3,21 @@
package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var errRouteNotFound = fmt.Errorf("route not found")
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
// TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
defaultGateway, _, err := getNextHop(addr)
if err != nil && !errors.Is(err, errRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(defaultGateway) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
if defaultGateway.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
var exitIntf string
gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, errRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
if intf != nil {
exitIntf = intf.Name
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
}
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
r, err := netroute.New()
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Errorf("Getting routes returned an error: %v", err)
return netip.Addr{}, nil, errRouteNotFound
}
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, errRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, ok := netip.AddrFromSlice(preferredSrc)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc)
}
return addr.Unmap(), intf, nil
}
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway)
}
return addr.Unmap(), intf, nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// getRouteToNonVPNIntf returns the next hop and interface for the given prefix.
// If the next hop or interface is pointing to the VPN interface, it will return an error
func addRouteToNonVPNIntf(
prefix netip.Prefix,
vpnIntf *iface.WGIface,
initialNextHop netip.Addr,
initialIntf *net.Interface,
) (netip.Addr, string, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback():
return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix)
case addr.IsLinkLocalUnicast():
return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix)
case addr.IsLinkLocalMulticast():
return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix)
case addr.IsInterfaceLocalMulticast():
return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix)
case addr.IsUnspecified():
return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix)
case addr.IsMulticast():
return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr)
if err != nil {
return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
var exitIntf string
if intf != nil {
exitIntf = intf.Name
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
if initialIntf != nil {
exitIntf = initialIntf.Name
}
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// addVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func addVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
return genericAddVPNRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// removeVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func removeVPNRoute(prefix netip.Prefix, intf string) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
if err != nil {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
if err != nil {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, string, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
return genericRemoveVPNRoute(prefix, intf)
}

View File

@ -1,13 +1,15 @@
//go:build !linux && !ios
//go:build !android && !ios
package routemanager
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"strings"
"testing"
@ -20,16 +22,9 @@ import (
"github.com/netbirdio/netbird/iface"
)
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
prefixGateway, _, err := getNextHop(prefix.Addr())
require.NoError(t, err, "getNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
}
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddRemoveRoutes(t *testing.T) {
@ -72,8 +67,8 @@ func TestAddRemoveRoutes(t *testing.T) {
assert.NoError(t, cleanupRouting())
})
err = addVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
@ -83,8 +78,8 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = removeVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "removeVPNRoute should not return err")
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
require.NoError(t, err, "getNextHop should not return err")
@ -144,7 +139,7 @@ func TestGetNextHop(t *testing.T) {
}
}
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultGateway: ", defaultGateway)
if err != nil {
@ -205,20 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = addVPNRoute(testCase.prefix, wgInterface.Name())
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@ -228,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = removeVPNRoute(testCase.prefix, wgInterface.Name())
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
require.NoError(t, err, "should not return err")
}
@ -284,12 +273,6 @@ func TestIsSubRange(t *testing.T) {
}
func TestExistsInRouteTable(t *testing.T) {
_, _, err := setupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
@ -298,10 +281,19 @@ func TestExistsInRouteTable(t *testing.T) {
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() {
addressPrefixes = append(addressPrefixes, p.Masked())
if p.Addr().Is6() {
continue
}
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
continue
}
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
continue
}
addressPrefixes = append(addressPrefixes, p.Masked())
}
for _, prefix := range addressPrefixes {
@ -314,3 +306,97 @@ func TestExistsInRouteTable(t *testing.T) {
}
}
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet()
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
require.NoError(t, err, "should create testing WireGuard interface")
t.Cleanup(func() {
wgInterface.Close()
})
return wgInterface
}
func setupTestEnv(t *testing.T) {
t.Helper()
setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgIface.Close())
})
_, _, err := setupRouting(nil, wgIface)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixGateway, _, err := getNextHop(prefix.Addr())
require.NoError(t, err, "getNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
}
}

View File

@ -1,101 +0,0 @@
//go:build !android && !ios
package routemanager
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
)
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet(nil)
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
require.NoError(t, err, "should create testing WireGuard interface")
t.Cleanup(func() {
wgInterface.Close()
})
return wgInterface
}
func setupTestEnv(t *testing.T) {
t.Helper()
setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgIface.Close())
})
_, _, err := setupRouting(nil, wgIface)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
})
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}

View File

@ -35,7 +35,7 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to closeConn connection: %v", err)
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}

View File

@ -145,10 +145,19 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
udpConn, err := net.ListenUDP(network, laddr)
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
connID := GenerateConnID()
return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil
packetConn := conn.(*PacketConn)
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
if !ok {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type")
}
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
}