mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 18:00:49 +01:00
fd67892cb4
Refactor the flat code structure
509 lines
15 KiB
Go
509 lines
15 KiB
Go
//go:build !android && !ios
|
|
|
|
package systemops
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"runtime"
|
|
"strconv"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/libp2p/go-netroute"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
"github.com/netbirdio/netbird/client/iface"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
|
nbnet "github.com/netbirdio/netbird/util/net"
|
|
)
|
|
|
|
// The two /1 halves of the IPv4 and IPv6 default routes. genericAddVPNRoute
// installs both halves instead of the default prefix itself, so the existing
// default route does not have to be replaced.
var (
	// 0.0.0.0/1
	splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
	// 128.0.0.0/1
	splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
	// ::/1
	splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
	// 8000::/1
	splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
)

// ErrRoutingIsSeparate signals that host routes cannot be compared against VPN
// routes. IsAddrRouted treats it as "not routed" without logging it.
// NOTE(review): presumably returned by platform-specific hasSeparateRouting
// implementations — confirm.
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
|
|
|
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
|
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
|
}
|
|
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
|
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
|
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
|
}
|
|
|
|
refCounter := refcounter.New(
|
|
func(prefix netip.Prefix, _ struct{}) (Nexthop, error) {
|
|
initialNexthop := initialNextHopV4
|
|
if prefix.Addr().Is6() {
|
|
initialNexthop = initialNextHopV6
|
|
}
|
|
|
|
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
|
|
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
|
|
log.Tracef("Adding for prefix %s: %v", prefix, err)
|
|
// These errors are not critical, but also we should not track and try to remove the routes either.
|
|
return nexthop, refcounter.ErrIgnore
|
|
}
|
|
return nexthop, err
|
|
},
|
|
r.removeFromRouteTable,
|
|
)
|
|
|
|
r.refCounter = refCounter
|
|
|
|
return r.setupHooks(initAddresses)
|
|
}
|
|
|
|
func (r *SysOps) cleanupRefCounter() error {
|
|
if r.refCounter == nil {
|
|
return nil
|
|
}
|
|
|
|
// TODO: Remove hooks selectively
|
|
nbnet.RemoveDialerHooks()
|
|
nbnet.RemoveListenerHooks()
|
|
|
|
if err := r.refCounter.Flush(); err != nil {
|
|
return fmt.Errorf("flush route manager: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// TODO: fix: for default our wg address now appears as the default gw
|
|
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|
addr := netip.IPv4Unspecified()
|
|
if prefix.Addr().Is6() {
|
|
addr = netip.IPv6Unspecified()
|
|
}
|
|
|
|
nexthop, err := GetNextHop(addr)
|
|
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
|
return fmt.Errorf("get existing route gateway: %s", err)
|
|
}
|
|
|
|
if !prefix.Contains(nexthop.IP) {
|
|
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
|
|
return nil
|
|
}
|
|
|
|
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
|
|
if nexthop.IP.Is6() {
|
|
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
|
|
}
|
|
|
|
ok, err := existsInRouteTable(gatewayPrefix)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
|
}
|
|
|
|
if ok {
|
|
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
|
return nil
|
|
}
|
|
|
|
nexthop, err = GetNextHop(nexthop.IP)
|
|
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
|
}
|
|
|
|
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
|
|
return r.addToRouteTable(gatewayPrefix, nexthop)
|
|
}
|
|
|
|
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
|
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
|
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
|
|
addr := prefix.Addr()
|
|
switch {
|
|
case addr.IsLoopback(),
|
|
addr.IsLinkLocalUnicast(),
|
|
addr.IsLinkLocalMulticast(),
|
|
addr.IsInterfaceLocalMulticast(),
|
|
addr.IsUnspecified(),
|
|
addr.IsMulticast():
|
|
|
|
return Nexthop{}, vars.ErrRouteNotAllowed
|
|
}
|
|
|
|
// Check if the prefix is part of any local subnets
|
|
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
|
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
|
}
|
|
|
|
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
|
nexthop, err := GetNextHop(addr)
|
|
if err != nil {
|
|
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
|
}
|
|
|
|
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
|
exitNextHop := Nexthop{
|
|
IP: nexthop.IP,
|
|
Intf: nexthop.Intf,
|
|
}
|
|
|
|
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
|
if !ok {
|
|
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
|
}
|
|
|
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
|
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
|
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
|
|
|
exitNextHop = initialNextHop
|
|
}
|
|
|
|
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
|
|
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
|
|
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
|
|
}
|
|
|
|
return exitNextHop, nil
|
|
}
|
|
|
|
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
|
localInterfaces, err := net.Interfaces()
|
|
if err != nil {
|
|
log.Errorf("Failed to get local interfaces: %v", err)
|
|
return false, nil
|
|
}
|
|
|
|
for _, intf := range localInterfaces {
|
|
addrs, err := intf.Addrs()
|
|
if err != nil {
|
|
log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err)
|
|
continue
|
|
}
|
|
|
|
for _, addr := range addrs {
|
|
ipnet, ok := addr.(*net.IPNet)
|
|
if !ok {
|
|
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
|
continue
|
|
}
|
|
|
|
if ipnet.Contains(prefix.Addr().AsSlice()) {
|
|
return true, ipnet
|
|
}
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
|
// in two /1 prefixes to avoid replacing the existing default route
|
|
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|
nextHop := Nexthop{netip.Addr{}, intf}
|
|
|
|
if prefix == vars.Defaultv4 {
|
|
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
|
return err
|
|
}
|
|
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
|
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
|
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// TODO: remove once IPv6 is supported on the interface
|
|
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
}
|
|
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
|
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
}
|
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
}
|
|
|
|
return nil
|
|
} else if prefix == vars.Defaultv6 {
|
|
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
|
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
}
|
|
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
|
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
|
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
}
|
|
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return r.addNonExistingRoute(prefix, intf)
|
|
}
|
|
|
|
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
|
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|
ok, err := existsInRouteTable(prefix)
|
|
if err != nil {
|
|
return fmt.Errorf("exists in route table: %w", err)
|
|
}
|
|
if ok {
|
|
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
|
return nil
|
|
}
|
|
|
|
ok, err = isSubRange(prefix)
|
|
if err != nil {
|
|
return fmt.Errorf("sub range: %w", err)
|
|
}
|
|
|
|
if ok {
|
|
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
|
|
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
|
}
|
|
}
|
|
|
|
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
|
|
}
|
|
|
|
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
|
// it will remove the split /1 prefixes
|
|
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|
nextHop := Nexthop{netip.Addr{}, intf}
|
|
|
|
if prefix == vars.Defaultv4 {
|
|
var result *multierror.Error
|
|
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
|
|
// TODO: remove once IPv6 is supported on the interface
|
|
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
|
|
return nberrors.FormatErrorOrNil(result)
|
|
} else if prefix == vars.Defaultv6 {
|
|
var result *multierror.Error
|
|
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
|
result = multierror.Append(result, err)
|
|
}
|
|
|
|
return nberrors.FormatErrorOrNil(result)
|
|
}
|
|
|
|
return r.removeFromRouteTable(prefix, nextHop)
|
|
}
|
|
|
|
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
|
prefix, err := util.GetPrefixFromIP(ip)
|
|
if err != nil {
|
|
return fmt.Errorf("convert ip to prefix: %w", err)
|
|
}
|
|
|
|
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
|
return fmt.Errorf("adding route reference: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
afterHook := func(connID nbnet.ConnectionID) error {
|
|
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
|
return fmt.Errorf("remove route reference: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
for _, ip := range initAddresses {
|
|
if err := beforeHook("init", ip); err != nil {
|
|
log.Errorf("Failed to add route reference: %v", err)
|
|
}
|
|
}
|
|
|
|
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
|
|
var result *multierror.Error
|
|
for _, ip := range resolvedIPs {
|
|
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
|
}
|
|
return nberrors.FormatErrorOrNil(result)
|
|
})
|
|
|
|
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
|
return afterHook(connID)
|
|
})
|
|
|
|
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
|
return beforeHook(connID, ip.IP)
|
|
})
|
|
|
|
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
|
return afterHook(connID)
|
|
})
|
|
|
|
return beforeHook, afterHook, nil
|
|
}
|
|
|
|
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
|
r, err := netroute.New()
|
|
if err != nil {
|
|
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
|
|
}
|
|
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
|
if err != nil {
|
|
log.Debugf("Failed to get route for %s: %v", ip, err)
|
|
return Nexthop{}, vars.ErrRouteNotFound
|
|
}
|
|
|
|
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
|
if gateway == nil {
|
|
if runtime.GOOS == "freebsd" {
|
|
return Nexthop{Intf: intf}, nil
|
|
}
|
|
|
|
if preferredSrc == nil {
|
|
return Nexthop{}, vars.ErrRouteNotFound
|
|
}
|
|
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
|
|
|
addr, err := ipToAddr(preferredSrc, intf)
|
|
if err != nil {
|
|
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
|
|
}
|
|
return Nexthop{
|
|
IP: addr,
|
|
Intf: intf,
|
|
}, nil
|
|
}
|
|
|
|
addr, err := ipToAddr(gateway, intf)
|
|
if err != nil {
|
|
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
|
|
}
|
|
|
|
return Nexthop{
|
|
IP: addr,
|
|
Intf: intf,
|
|
}, nil
|
|
}
|
|
|
|
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
|
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
|
addr, ok := netip.AddrFromSlice(ip)
|
|
if !ok {
|
|
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
|
}
|
|
|
|
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
|
zone := intf.Name
|
|
if runtime.GOOS == "windows" {
|
|
zone = strconv.Itoa(intf.Index)
|
|
}
|
|
log.Tracef("Adding zone %s to address %s", zone, addr)
|
|
addr = addr.WithZone(zone)
|
|
}
|
|
|
|
return addr.Unmap(), nil
|
|
}
|
|
|
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
|
routes, err := GetRoutesFromTable()
|
|
if err != nil {
|
|
return false, fmt.Errorf("get routes from table: %w", err)
|
|
}
|
|
for _, tableRoute := range routes {
|
|
if tableRoute == prefix {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func isSubRange(prefix netip.Prefix) (bool, error) {
|
|
routes, err := GetRoutesFromTable()
|
|
if err != nil {
|
|
return false, fmt.Errorf("get routes from table: %w", err)
|
|
}
|
|
for _, tableRoute := range routes {
|
|
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
|
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
|
localRoutes, err := hasSeparateRouting()
|
|
if err != nil {
|
|
if !errors.Is(err, ErrRoutingIsSeparate) {
|
|
log.Errorf("Failed to get routes: %v", err)
|
|
}
|
|
return false, netip.Prefix{}
|
|
}
|
|
|
|
return isVpnRoute(addr, vpnRoutes, localRoutes)
|
|
}
|
|
|
|
// isVpnRoute performs a longest-prefix match of addr over vpnRoutes and
// localRoutes. Zero-length (default) prefixes are ignored.
//
// It returns the winning prefix, plus true only if that prefix belongs to
// vpnRoutes and does not also appear in localRoutes. On equal lengths the
// earlier prefix wins, with vpnRoutes scanned before localRoutes. With no match
// it returns false and the zero Prefix.
func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) {
	// VPN prefixes that the host does not also route locally.
	vpnOnly := make(map[netip.Prefix]struct{}, len(vpnRoutes))
	for _, p := range vpnRoutes {
		vpnOnly[p] = struct{}{}
	}
	for _, p := range localRoutes {
		delete(vpnOnly, p)
	}

	var best netip.Prefix
	bestIsVpn := false

	consider := func(p netip.Prefix) {
		// Default routes have special handling elsewhere.
		if p.Bits() == 0 || !p.Contains(addr) {
			return
		}
		if best.IsValid() && p.Bits() <= best.Bits() {
			return
		}
		best = p
		_, bestIsVpn = vpnOnly[p]
	}

	for _, p := range vpnRoutes {
		consider(p)
	}
	for _, p := range localRoutes {
		consider(p)
	}

	if !best.IsValid() {
		return false, netip.Prefix{}
	}
	return bestIsVpn, best
}
|