Remove the gateway check for routes (#1317)

Most operating systems add a /32 route for the default gateway address to its routing table

This will allow routes to be configured into the system even when the incoming range contains the default gateway.

In case a range is a sub-range of an existing route and this range happens to contain the default gateway it attempts to create a default gateway route to prevent loop issues
This commit is contained in:
Maycon Santos 2023-11-24 11:31:22 +01:00 committed by GitHub
parent 5a3ee4f9c4
commit fdd23d4644
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 236 additions and 77 deletions

View File

@ -12,6 +12,8 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const minRangeBits = 7
type routerPeerStatus struct { type routerPeerStatus struct {
connected bool connected bool
relayed bool relayed bool

View File

@ -155,7 +155,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
if !ownNetworkIDs[networkID] { if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported // if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management // we skip this route management
if newRoute.Network.Bits() < 7 { if newRoute.Network.Bits() < minRangeBits {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route", log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
version.NetbirdVersion(), newRoute.Network) version.NetbirdVersion(), newRoute.Network)
continue continue

View File

@ -27,24 +27,24 @@ const (
RTF_MULTICAST = 0x800000 RTF_MULTICAST = 0x800000
) )
func existsInRouteTable(prefix netip.Prefix) (bool, error) { func getRoutesFromTable() ([]netip.Prefix, error) {
tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil { if err != nil {
return false, err return nil, err
} }
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab) msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
if err != nil { if err != nil {
return false, err return nil, err
} }
var prefixList []netip.Prefix
for _, msg := range msgs { for _, msg := range msgs {
m := msg.(*route.RouteMessage) m := msg.(*route.RouteMessage)
if m.Version < 3 || m.Version > 5 { if m.Version < 3 || m.Version > 5 {
return false, fmt.Errorf("unexpected RIB message version: %d", m.Version) return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version)
} }
if m.Type != 4 /* RTM_GET */ { if m.Type != 4 /* RTM_GET */ {
return true, fmt.Errorf("unexpected RIB message type: %d", m.Type) return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
} }
if m.Flags&RTF_UP == 0 || if m.Flags&RTF_UP == 0 ||
@ -52,31 +52,42 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
continue continue
} }
dst, err := toIPAddr(m.Addrs[0]) addr, ok := toNetIPAddr(m.Addrs[0])
if err != nil { if !ok {
return true, fmt.Errorf("unexpected RIB destination: %v", err) continue
} }
mask, _ := toIPAddr(m.Addrs[2]) mask, ok := toNetIPMASK(m.Addrs[2])
cidr, _ := net.IPMask(mask.To4()).Size() if !ok {
if dst.String() == prefix.Addr().String() && cidr == prefix.Bits() { continue
return true, nil
} }
cidr, _ := mask.Size()
routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() {
prefixList = append(prefixList, routePrefix)
}
}
return prefixList, nil
} }
return false, nil func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
}
func toIPAddr(a route.Addr) (net.IP, error) {
switch t := a.(type) { switch t := a.(type) {
case *route.Inet4Addr: case *route.Inet4Addr:
ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
return ip, nil addr := netip.MustParseAddr(ip.String())
case *route.Inet6Addr: return addr, true
ip := make(net.IP, net.IPv6len)
copy(ip, t.IP[:])
return ip, nil
default: default:
return net.IP{}, fmt.Errorf("unknown family: %v", t) return netip.Addr{}, false
}
}
func toNetIPMASK(a route.Addr) (net.IPMask, bool) {
switch t := a.(type) {
case *route.Inet4Addr:
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
return mask, true
default:
return nil, false
} }
} }

View File

@ -60,15 +60,26 @@ func addToRouteTable(prefix netip.Prefix, addr string) error {
return nil return nil
} }
func removeFromRouteTable(prefix netip.Prefix) error { func removeFromRouteTable(prefix netip.Prefix, addr string) error {
_, ipNet, err := net.ParseCIDR(prefix.String()) _, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil { if err != nil {
return err return err
} }
addrMask := "/32"
if prefix.Addr().Unmap().Is6() {
addrMask = "/128"
}
ip, _, err := net.ParseCIDR(addr + addrMask)
if err != nil {
return err
}
route := &netlink.Route{ route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE, Scope: netlink.SCOPE_UNIVERSE,
Dst: ipNet, Dst: ipNet,
Gw: ip,
} }
err = netlink.RouteDel(route) err = netlink.RouteDel(route)
@ -79,15 +90,16 @@ func removeFromRouteTable(prefix netip.Prefix) error {
return nil return nil
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) { func getRoutesFromTable() ([]netip.Prefix, error) {
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
if err != nil { if err != nil {
return true, err return nil, err
} }
msgs, err := syscall.ParseNetlinkMessage(tab) msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil { if err != nil {
return true, err return nil, err
} }
var prefixList []netip.Prefix
loop: loop:
for _, m := range msgs { for _, m := range msgs {
switch m.Header.Type { switch m.Header.Type {
@ -97,7 +109,7 @@ loop:
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0])) rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
attrs, err := syscall.ParseNetlinkRouteAttr(&m) attrs, err := syscall.ParseNetlinkRouteAttr(&m)
if err != nil { if err != nil {
return true, err return nil, err
} }
if rt.Family != syscall.AF_INET { if rt.Family != syscall.AF_INET {
continue loop continue loop
@ -105,17 +117,21 @@ loop:
for _, attr := range attrs { for _, attr := range attrs {
if attr.Attr.Type == syscall.RTA_DST { if attr.Attr.Type == syscall.RTA_DST {
ip := net.IP(attr.Value) addr, ok := netip.AddrFromSlice(attr.Value)
if !ok {
continue
}
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8) mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
cidr, _ := mask.Size() cidr, _ := mask.Size()
if ip.String() == prefix.Addr().String() && cidr == prefix.Bits() { routePrefix := netip.PrefixFrom(addr, cidr)
return true, nil if routePrefix.IsValid() && routePrefix.Addr().Is4() {
prefixList = append(prefixList, routePrefix)
} }
} }
} }
} }
} }
return false, nil return prefixList, nil
} }
func enableIPForwarding() error { func enableIPForwarding() error {

View File

@ -14,17 +14,6 @@ import (
var errRouteNotFound = fmt.Errorf("route not found") var errRouteNotFound = fmt.Errorf("route not found")
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil && err != errRouteNotFound {
return err
}
gatewayIP := netip.MustParseAddr(defaultGateway.String())
if prefix.Contains(gatewayIP) {
log.Warnf("skipping adding a new route for network %s because it overlaps with the default gateway: %s", prefix, gatewayIP)
return nil
}
ok, err := existsInRouteTable(prefix) ok, err := existsInRouteTable(prefix)
if err != nil { if err != nil {
return err return err
@ -34,20 +23,82 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil return nil
} }
return addToRouteTable(prefix, addr) ok, err = isSubRange(prefix)
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
addrIP := net.ParseIP(addr)
prefixGateway, err := getExistingRIBRouteGateway(prefix)
if err != nil { if err != nil {
return err return err
} }
if prefixGateway != nil && !prefixGateway.Equal(addrIP) {
log.Warnf("route for network %s is pointing to a different gateway: %s, should be pointing to: %s, not removing", prefix, prefixGateway, addrIP) if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, addr)
}
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil && err != errRouteNotFound {
return err
}
addr := netip.MustParseAddr(defaultGateway.String())
if !prefix.Contains(addr) {
log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
return nil return nil
} }
return removeFromRouteTable(prefix)
gatewayPrefix := netip.PrefixFrom(addr, 32)
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
if err != nil && err != errRouteNotFound {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop.String())
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, err
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, err
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return removeFromRouteTable(prefix, addr)
} }
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {

View File

@ -24,13 +24,13 @@ func TestAddRemoveRoutes(t *testing.T) {
shouldBeRemoved bool shouldBeRemoved bool
}{ }{
{ {
name: "Should Add And Remove Route", name: "Should Add And Remove Route 100.66.120.0/24",
prefix: netip.MustParsePrefix("100.66.120.0/24"), prefix: netip.MustParsePrefix("100.66.120.0/24"),
shouldRouteToWireguard: true, shouldRouteToWireguard: true,
shouldBeRemoved: true, shouldBeRemoved: true,
}, },
{ {
name: "Should Not Add Or Remove Route", name: "Should Not Add Or Remove Route 127.0.0.1/32",
prefix: netip.MustParsePrefix("127.0.0.1/32"), prefix: netip.MustParsePrefix("127.0.0.1/32"),
shouldRouteToWireguard: false, shouldRouteToWireguard: false,
shouldBeRemoved: false, shouldBeRemoved: false,
@ -51,21 +51,23 @@ func TestAddRemoveRoutes(t *testing.T) {
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
require.NoError(t, err, "should not return err") require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "should not return err") require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
if testCase.shouldRouteToWireguard { if testCase.shouldRouteToWireguard {
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
} else { } else {
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
} }
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
require.NoError(t, err, "should not return err") require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "should not return err") require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
require.NoError(t, err) require.NoError(t, err)
@ -75,6 +77,7 @@ func TestAddRemoveRoutes(t *testing.T) {
} else { } else {
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
} }
}
}) })
} }
} }
@ -215,3 +218,66 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
}) })
} }
} }
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is4() {
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}

View File

@ -21,8 +21,12 @@ func addToRouteTable(prefix netip.Prefix, addr string) error {
return nil return nil
} }
func removeFromRouteTable(prefix netip.Prefix) error { func removeFromRouteTable(prefix netip.Prefix, addr string) error {
cmd := exec.Command("route", "delete", prefix.String()) args := []string{"delete", prefix.String()}
if runtime.GOOS == "darwin" {
args = append(args, addr)
}
cmd := exec.Command("route", args...)
out, err := cmd.Output() out, err := cmd.Output()
if err != nil { if err != nil {
return err return err

View File

@ -15,23 +15,32 @@ type Win32_IP4RouteTable struct {
Mask string Mask string
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) { func getRoutesFromTable() ([]netip.Prefix, error) {
var routes []Win32_IP4RouteTable var routes []Win32_IP4RouteTable
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
err := wmi.Query(query, &routes) err := wmi.Query(query, &routes)
if err != nil { if err != nil {
return true, err return nil, err
} }
var prefixList []netip.Prefix
for _, route := range routes { for _, route := range routes {
ip := net.ParseIP(route.Mask) addr, err := netip.ParseAddr(route.Destination)
ip = ip.To4() if err != nil {
mask := net.IPv4Mask(ip[0], ip[1], ip[2], ip[3]) continue
}
maskSlice := net.ParseIP(route.Mask).To4()
if maskSlice == nil {
continue
}
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
cidr, _ := mask.Size() cidr, _ := mask.Size()
if route.Destination == prefix.Addr().String() && cidr == prefix.Bits() {
return true, nil routePrefix := netip.PrefixFrom(addr, cidr)
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
prefixList = append(prefixList, routePrefix)
} }
} }
return false, nil return prefixList, nil
} }