mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-29 19:43:57 +01:00
415 lines
13 KiB
Go
415 lines
13 KiB
Go
//go:build !android && !ios
|
|
|
|
package routemanager
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"runtime"
|
|
"strconv"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/libp2p/go-netroute"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/iface"
|
|
nbnet "github.com/netbirdio/netbird/util/net"
|
|
)
|
|
|
|
// Pairs of /1 prefixes, one pair per address family. Each pair covers the
// whole address space of its family and is used in place of the default
// prefix so the existing default route does not get replaced.
var (
	splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
	splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128, 0, 0, 0}), 1)
	splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
	splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 0x80}), 1)
)

// Sentinel errors returned by the routing helpers in this file.
var (
	// ErrRouteNotFound is returned when no route could be determined for an address.
	ErrRouteNotFound = errors.New("route not found")
	// ErrRouteNotAllowed is returned when adding a route for the given prefix is refused.
	ErrRouteNotAllowed = errors.New("route not allowed")
)
|
|
|
|
// TODO: fix: for default our wg address now appears as the default gw
|
|
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|
addr := netip.IPv4Unspecified()
|
|
if prefix.Addr().Is6() {
|
|
addr = netip.IPv6Unspecified()
|
|
}
|
|
|
|
defaultGateway, _, err := getNextHop(addr)
|
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
return fmt.Errorf("get existing route gateway: %s", err)
|
|
}
|
|
|
|
if !prefix.Contains(defaultGateway) {
|
|
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
|
return nil
|
|
}
|
|
|
|
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
|
if defaultGateway.Is6() {
|
|
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
|
}
|
|
|
|
ok, err := existsInRouteTable(gatewayPrefix)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
|
}
|
|
|
|
if ok {
|
|
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
|
return nil
|
|
}
|
|
|
|
gatewayHop, intf, err := getNextHop(defaultGateway)
|
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
|
}
|
|
|
|
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
|
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
|
}
|
|
|
|
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
|
r, err := netroute.New()
|
|
if err != nil {
|
|
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
|
}
|
|
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
|
if err != nil {
|
|
log.Warnf("Failed to get route for %s: %v", ip, err)
|
|
return netip.Addr{}, nil, ErrRouteNotFound
|
|
}
|
|
|
|
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
|
if gateway == nil {
|
|
if preferredSrc == nil {
|
|
return netip.Addr{}, nil, ErrRouteNotFound
|
|
}
|
|
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
|
|
|
addr, err := ipToAddr(preferredSrc, intf)
|
|
if err != nil {
|
|
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
|
|
}
|
|
return addr.Unmap(), intf, nil
|
|
}
|
|
|
|
addr, err := ipToAddr(gateway, intf)
|
|
if err != nil {
|
|
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
|
|
}
|
|
|
|
return addr, intf, nil
|
|
}
|
|
|
|
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
|
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
|
addr, ok := netip.AddrFromSlice(ip)
|
|
if !ok {
|
|
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
|
}
|
|
|
|
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
|
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
|
|
if runtime.GOOS == "windows" {
|
|
addr = addr.WithZone(strconv.Itoa(intf.Index))
|
|
} else {
|
|
addr = addr.WithZone(intf.Name)
|
|
}
|
|
}
|
|
|
|
return addr.Unmap(), nil
|
|
}
|
|
|
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
|
routes, err := getRoutesFromTable()
|
|
if err != nil {
|
|
return false, fmt.Errorf("get routes from table: %w", err)
|
|
}
|
|
for _, tableRoute := range routes {
|
|
if tableRoute == prefix {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func isSubRange(prefix netip.Prefix) (bool, error) {
|
|
routes, err := getRoutesFromTable()
|
|
if err != nil {
|
|
return false, fmt.Errorf("get routes from table: %w", err)
|
|
}
|
|
for _, tableRoute := range routes {
|
|
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
|
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
|
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
|
|
addr := prefix.Addr()
|
|
switch {
|
|
case addr.IsLoopback(),
|
|
addr.IsLinkLocalUnicast(),
|
|
addr.IsLinkLocalMulticast(),
|
|
addr.IsInterfaceLocalMulticast(),
|
|
addr.IsUnspecified(),
|
|
addr.IsMulticast():
|
|
|
|
return netip.Addr{}, nil, ErrRouteNotAllowed
|
|
}
|
|
|
|
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
|
nexthop, intf, err := getNextHop(addr)
|
|
if err != nil {
|
|
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
|
}
|
|
|
|
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
|
exitNextHop := nexthop
|
|
exitIntf := intf
|
|
|
|
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
|
if !ok {
|
|
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
|
}
|
|
|
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
|
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
|
|
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
|
exitNextHop = initialNextHop
|
|
exitIntf = initialIntf
|
|
}
|
|
|
|
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
|
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
|
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
|
|
}
|
|
|
|
return exitNextHop, exitIntf, nil
|
|
}
|
|
|
|
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
|
// in two /1 prefixes to avoid replacing the existing default route
|
|
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|
if prefix == defaultv4 {
|
|
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
|
return err
|
|
}
|
|
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
|
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// TODO: remove once IPv6 is supported on the interface
|
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
}
|
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
}
|
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
}
|
|
|
|
return nil
|
|
} else if prefix == defaultv6 {
|
|
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
}
|
|
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
}
|
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return addNonExistingRoute(prefix, intf)
|
|
}
|
|
|
|
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
|
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|
ok, err := existsInRouteTable(prefix)
|
|
if err != nil {
|
|
return fmt.Errorf("exists in route table: %w", err)
|
|
}
|
|
if ok {
|
|
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
|
return nil
|
|
}
|
|
|
|
ok, err = isSubRange(prefix)
|
|
if err != nil {
|
|
return fmt.Errorf("sub range: %w", err)
|
|
}
|
|
|
|
if ok {
|
|
err := addRouteForCurrentDefaultGateway(prefix)
|
|
if err != nil {
|
|
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
|
}
|
|
}
|
|
|
|
return addToRouteTable(prefix, netip.Addr{}, intf)
|
|
}
|
|
|
|
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
|
// it will remove the split /1 prefixes
|
|
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|
if prefix == defaultv4 {
|
|
var result *multierror.Error
|
|
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
|
|
// TODO: remove once IPv6 is supported on the interface
|
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
|
|
return result.ErrorOrNil()
|
|
} else if prefix == defaultv6 {
|
|
var result *multierror.Error
|
|
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
|
|
return result.ErrorOrNil()
|
|
}
|
|
|
|
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
|
}
|
|
|
|
// getPrefixFromIP converts ip into a single-host prefix: /32 for IPv4
// (including IPv4-mapped IPv6 input, which is unmapped first) and /128 for
// IPv6. An error is returned when ip is not 4 or 16 bytes long.
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
	addr, ok := netip.AddrFromSlice(ip)
	if !ok {
		return nil, fmt.Errorf("parse IP address: %s", ip)
	}
	addr = addr.Unmap()

	// After a successful conversion the address is either IPv4 or IPv6, so
	// BitLen is 32 or 128, i.e. exactly the host prefix length.
	prefix := netip.PrefixFrom(addr, addr.BitLen())
	return &prefix, nil
}
|
|
|
|
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
|
}
|
|
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
|
}
|
|
|
|
*routeManager = NewRouteManager(
|
|
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
|
|
addr := prefix.Addr()
|
|
nexthop, intf := initialNextHopV4, initialIntfV4
|
|
if addr.Is6() {
|
|
nexthop, intf = initialNextHopV6, initialIntfV6
|
|
}
|
|
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
|
},
|
|
removeFromRouteTable,
|
|
)
|
|
|
|
return setupHooks(*routeManager, initAddresses)
|
|
}
|
|
|
|
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
|
if routeManager == nil {
|
|
return nil
|
|
}
|
|
|
|
// TODO: Remove hooks selectively
|
|
nbnet.RemoveDialerHooks()
|
|
nbnet.RemoveListenerHooks()
|
|
|
|
if err := routeManager.Flush(); err != nil {
|
|
return fmt.Errorf("flush route manager: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
|
prefix, err := getPrefixFromIP(ip)
|
|
if err != nil {
|
|
return fmt.Errorf("convert ip to prefix: %w", err)
|
|
}
|
|
|
|
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
|
return fmt.Errorf("adding route reference: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
afterHook := func(connID nbnet.ConnectionID) error {
|
|
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
|
return fmt.Errorf("remove route reference: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
for _, ip := range initAddresses {
|
|
if err := beforeHook("init", ip); err != nil {
|
|
log.Errorf("Failed to add route reference: %v", err)
|
|
}
|
|
}
|
|
|
|
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
|
|
var result *multierror.Error
|
|
for _, ip := range resolvedIPs {
|
|
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
|
}
|
|
return result.ErrorOrNil()
|
|
})
|
|
|
|
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
|
return afterHook(connID)
|
|
})
|
|
|
|
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
|
return beforeHook(connID, ip.IP)
|
|
})
|
|
|
|
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
|
return afterHook(connID)
|
|
})
|
|
|
|
return beforeHook, afterHook, nil
|
|
}
|