diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index fda7b012f..ee98d503d 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -12,6 +12,8 @@ import ( "github.com/netbirdio/netbird/route" ) +const minRangeBits = 7 + type routerPeerStatus struct { connected bool relayed bool diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 1f812983c..479ac873f 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -155,7 +155,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string] if !ownNetworkIDs[networkID] { // if prefix is too small, lets assume is a possible default route which is not yet supported // we skip this route management - if newRoute.Network.Bits() < 7 { + if newRoute.Network.Bits() < minRangeBits { log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", version.NetbirdVersion(), newRoute.Network) continue diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index e777ec8ec..b2da8075c 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -27,24 +27,24 @@ const ( RTF_MULTICAST = 0x800000 ) -func existsInRouteTable(prefix netip.Prefix) (bool, error) { +func getRoutesFromTable() ([]netip.Prefix, error) { tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { - return false, err + return nil, err } msgs, err := route.ParseRIB(route.RIBTypeRoute, tab) if err != nil { - return false, err + return nil, err } - + var prefixList []netip.Prefix for _, msg := range msgs { m := msg.(*route.RouteMessage) if m.Version < 3 || m.Version > 5 { - return false, fmt.Errorf("unexpected RIB message version: %d", m.Version) + return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version) } if m.Type != 4 /* RTM_GET */ { - return true, fmt.Errorf("unexpected RIB message type: %d", m.Type) + return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type) } if m.Flags&RTF_UP == 0 || @@ -52,31 +52,42 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) { continue } - dst, err := toIPAddr(m.Addrs[0]) - if err != nil { - return true, fmt.Errorf("unexpected RIB destination: %v", err) + addr, ok := toNetIPAddr(m.Addrs[0]) + if !ok { + continue } - mask, _ := toIPAddr(m.Addrs[2]) - cidr, _ := net.IPMask(mask.To4()).Size() - if dst.String() == prefix.Addr().String() && cidr == prefix.Bits() { - return true, nil + mask, ok := toNetIPMASK(m.Addrs[2]) + if !ok { + continue + } + cidr, _ := mask.Size() + + routePrefix := netip.PrefixFrom(addr, cidr) + if routePrefix.IsValid() { + prefixList = append(prefixList, routePrefix) } } - - return false, nil + return prefixList, nil } -func toIPAddr(a route.Addr) (net.IP, error) { +func toNetIPAddr(a route.Addr) (netip.Addr, bool) { switch t := a.(type) { case *route.Inet4Addr: ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - return ip, nil - case *route.Inet6Addr: - ip := make(net.IP, net.IPv6len) - copy(ip, t.IP[:]) - return ip, nil + addr := netip.MustParseAddr(ip.String()) + return addr, true default: - return net.IP{}, fmt.Errorf("unknown family: %v", t) + return netip.Addr{}, false + } +} + +func toNetIPMASK(a route.Addr) (net.IPMask, bool) { + switch t := a.(type) { + case *route.Inet4Addr: + mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) + return mask, true + default: + return nil, false } } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index fb2938d55..b5b4f5696 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -60,15 +60,26 @@ func addToRouteTable(prefix netip.Prefix, addr string) error { return nil } -func removeFromRouteTable(prefix netip.Prefix) error { +func removeFromRouteTable(prefix netip.Prefix, addr string) error { _, ipNet, err := net.ParseCIDR(prefix.String()) if err != nil { return err } + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" + } + + ip, _, err := net.ParseCIDR(addr + addrMask) + if err != nil { + return err + } + route := &netlink.Route{ Scope: netlink.SCOPE_UNIVERSE, Dst: ipNet, + Gw: ip, } err = netlink.RouteDel(route) @@ -79,15 +90,16 @@ func removeFromRouteTable(prefix netip.Prefix) error { return nil } -func existsInRouteTable(prefix netip.Prefix) (bool, error) { +func getRoutesFromTable() ([]netip.Prefix, error) { tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) if err != nil { - return true, err + return nil, err } msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { - return true, err + return nil, err } + var prefixList []netip.Prefix loop: for _, m := range msgs { switch m.Header.Type { @@ -97,7 +109,7 @@ loop: rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) attrs, err := syscall.ParseNetlinkRouteAttr(&m) if err != nil { - return true, err + return nil, err } if rt.Family != syscall.AF_INET { continue loop @@ -105,17 +117,21 @@ loop: for _, attr := range attrs { if attr.Attr.Type == syscall.RTA_DST { - ip := net.IP(attr.Value) + addr, ok := netip.AddrFromSlice(attr.Value) + if !ok { + continue + } mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) cidr, _ := mask.Size() - if ip.String() == prefix.Addr().String() && cidr == prefix.Bits() { - return true, nil + routePrefix := netip.PrefixFrom(addr, cidr) + if routePrefix.IsValid() && routePrefix.Addr().Is4() { + prefixList = append(prefixList, routePrefix) } } } } } - return false, nil + return prefixList, nil } func enableIPForwarding() error { diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go index 3ddf72686..b229a580f 100644 --- a/client/internal/routemanager/systemops_nonandroid.go +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -14,17 +14,6 @@ import ( var errRouteNotFound = fmt.Errorf("route not found") func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil && err != errRouteNotFound { - return err - } - - gatewayIP := netip.MustParseAddr(defaultGateway.String()) - if prefix.Contains(gatewayIP) { - log.Warnf("skipping adding a new route for network %s because it overlaps with the default gateway: %s", prefix, gatewayIP) - return nil - } - ok, err := existsInRouteTable(prefix) if err != nil { return err @@ -34,20 +23,82 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { return nil } - return addToRouteTable(prefix, addr) -} - -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { - addrIP := net.ParseIP(addr) - prefixGateway, err := getExistingRIBRouteGateway(prefix) + ok, err = isSubRange(prefix) if err != nil { return err } - if prefixGateway != nil && !prefixGateway.Equal(addrIP) { - log.Warnf("route for network %s is pointing to a different gateway: %s, should be pointing to: %s, not removing", prefix, prefixGateway, addrIP) + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, addr) +} + +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + if err != nil && err != errRouteNotFound { + return err + } + + addr := netip.MustParseAddr(defaultGateway.String()) + + if !prefix.Contains(addr) { + log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) return nil } - return removeFromRouteTable(prefix) + + gatewayPrefix := netip.PrefixFrom(addr, 32) + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) + if err != nil && err != errRouteNotFound { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop.String()) +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, err + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { + return removeFromRouteTable(prefix, addr) } func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go index bb31834d1..3646dc3da 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -24,13 +24,13 @@ func TestAddRemoveRoutes(t *testing.T) { shouldBeRemoved bool }{ { - name: "Should Add And Remove Route", + name: "Should Add And Remove Route 100.66.120.0/24", prefix: netip.MustParsePrefix("100.66.120.0/24"), shouldRouteToWireguard: true, shouldBeRemoved: true, }, { - name: "Should Not Add Or Remove Route", + name: "Should Not Add Or Remove Route 127.0.0.1/32", prefix: netip.MustParsePrefix("127.0.0.1/32"), shouldRouteToWireguard: false, shouldBeRemoved: false, @@ -51,29 +51,32 @@ func TestAddRemoveRoutes(t *testing.T) { require.NoError(t, err, "should create testing wireguard interface") err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) - require.NoError(t, err, "should not return err") + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "should not return err") + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") if testCase.shouldRouteToWireguard { require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") } else { require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") } + exists, err := existsInRouteTable(testCase.prefix) + require.NoError(t, err, "existsInRouteTable should not return err") + if exists && testCase.shouldRouteToWireguard { + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) + require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) - require.NoError(t, err, "should not return err") + prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) + require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "should not return err") + internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + require.NoError(t, err) - internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - require.NoError(t, err) - - if testCase.shouldBeRemoved { - require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") - } else { - require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") + if testCase.shouldBeRemoved { + require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") + } else { + require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") + } } }) } @@ -215,3 +218,66 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { }) } } + +func TestExistsInRouteTable(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is4() { + addressPrefixes = append(addressPrefixes, p.Masked()) + } + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} + +func TestIsSubRange(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var subRangeAddressPrefixes []netip.Prefix + var nonSubRangeAddressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { + p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) + subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) + nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) + } + } + + for _, prefix := range subRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if !isSubRangePrefix { + t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) + } + } + + for _, prefix := range nonSubRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if isSubRangePrefix { + t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) + } + } +} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index 537042099..47bd60eb0 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -21,8 +21,12 @@ func addToRouteTable(prefix netip.Prefix, addr string) error { return nil } -func removeFromRouteTable(prefix netip.Prefix) error { - cmd := exec.Command("route", "delete", prefix.String()) +func removeFromRouteTable(prefix netip.Prefix, addr string) error { + args := []string{"delete", prefix.String()} + if runtime.GOOS == "darwin" { + args = append(args, addr) + } + cmd := exec.Command("route", args...) out, err := cmd.Output() if err != nil { return err diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 2233748bf..309c184b9 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -15,23 +15,32 @@ type Win32_IP4RouteTable struct { Mask string } -func existsInRouteTable(prefix netip.Prefix) (bool, error) { +func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" err := wmi.Query(query, &routes) if err != nil { - return true, err + return nil, err } + var prefixList []netip.Prefix for _, route := range routes { - ip := net.ParseIP(route.Mask) - ip = ip.To4() - mask := net.IPv4Mask(ip[0], ip[1], ip[2], ip[3]) + addr, err := netip.ParseAddr(route.Destination) + if err != nil { + continue + } + maskSlice := net.ParseIP(route.Mask).To4() + if maskSlice == nil { + continue + } + mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) cidr, _ := mask.Size() - if route.Destination == prefix.Addr().String() && cidr == prefix.Bits() { - return true, nil + + routePrefix := netip.PrefixFrom(addr, cidr) + if routePrefix.IsValid() && routePrefix.Addr().Is4() { + prefixList = append(prefixList, routePrefix) } } - return false, nil + return prefixList, nil }