Allow candidates on local routes if more specific than vpn routes (#2097)

This commit is contained in:
Viktor Liu 2024-06-17 09:47:17 +02:00 committed by GitHub
parent 245b086646
commit c9f3854dde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 227 additions and 37 deletions

View File

@ -42,6 +42,7 @@ import (
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@ -105,8 +106,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer // peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn peerConns map[string]*peer.Conn
beforePeerHook peer.BeforeAddPeerHookFunc beforePeerHook nbnet.AddHookFunc
afterPeerHook peer.AfterRemovePeerHookFunc afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager

View File

@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@ -99,9 +100,6 @@ type IceCredentials struct {
Pwd string Pwd string
} }
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
type Conn struct { type Conn struct {
config ConnConfig config ConnConfig
mu sync.Mutex mu sync.Mutex
@ -136,8 +134,8 @@ type Conn struct {
sentExtraSrflx bool sentExtraSrflx bool
connID nbnet.ConnectionID connID nbnet.ConnectionID
beforeAddPeerHooks []BeforeAddPeerHookFunc beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc
} }
// GetConf returns the connection config // GetConf returns the connection config
@ -380,11 +378,11 @@ func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay return candidate.Type() == ice.CandidateTypeRelay
} }
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
} }
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
} }
@ -801,10 +799,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
} }
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var routePrefixes []netip.Prefix var vpnRoutes []netip.Prefix
for _, routes := range clientRoutes { for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil { if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network) vpnRoutes = append(vpnRoutes, routes[0].Network)
} }
} }
@ -814,16 +812,10 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
return false return false
} }
for _, prefix := range routePrefixes { if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
// default route is log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix)
if prefix.Bits() == 0 { return true
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
} }
return false return false
} }

View File

@ -28,7 +28,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
@ -113,7 +113,7 @@ func NewManager(
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() { if nbnet.CustomRoutingDisabled() {
return nil, nil, nil return nil, nil, nil
} }

View File

@ -6,10 +6,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
@ -20,7 +20,7 @@ type MockManager struct {
StopFunc func() StopFunc func()
} }
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil return nil, nil, nil
} }

View File

@ -16,7 +16,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
@ -29,7 +28,9 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err) log.Errorf("Unable to get initial v4 default next hop: %v", err)
@ -273,7 +274,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@ -414,3 +415,58 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
} }
return false, nil return false, nil
} }
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()
if err != nil {
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{}
}
return isVpnRoute(addr, vpnRoutes, localRoutes)
}
func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) {
vpnPrefixMap := map[netip.Prefix]struct{}{}
for _, prefix := range vpnRoutes {
vpnPrefixMap[prefix] = struct{}{}
}
// remove vpnRoute duplicates
for _, prefix := range localRoutes {
delete(vpnPrefixMap, prefix)
}
var longestPrefix netip.Prefix
var isVpn bool
combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes))
copy(combinedRoutes, vpnRoutes)
copy(combinedRoutes[len(vpnRoutes):], localRoutes)
for _, prefix := range combinedRoutes {
// Ignore the default route, it has special handling
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
// Longest prefix match
if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() {
longestPrefix = prefix
_, isVpn = vpnPrefixMap[prefix]
}
}
}
if !longestPrefix.IsValid() {
// No route matched
return false, netip.Prefix{}
}
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}

View File

@ -420,3 +420,125 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP") assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
} }
} }
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@ -16,7 +16,6 @@ import (
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@ -86,7 +85,7 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
if isLegacy() { if isLegacy() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses) return r.setupRefCounter(initAddresses)
@ -502,3 +501,10 @@ func getAddressFamily(prefix netip.Prefix) int {
} }
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() {
return getRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}

View File

@ -9,10 +9,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return nil, nil, nil return nil, nil, nil
} }
@ -32,3 +32,7 @@ func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil
} }
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}

View File

@ -22,3 +22,7 @@ func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
return getRoutesFromTable()
}

View File

@ -13,10 +13,10 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses) return r.setupRefCounter(initAddresses)
} }

View File

@ -17,7 +17,7 @@ import (
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/peer" nbnet "github.com/netbirdio/netbird/util/net"
) )
type MSFT_NetRoute struct { type MSFT_NetRoute struct {
@ -56,7 +56,7 @@ var prefixList []netip.Prefix
var lastUpdate time.Time var lastUpdate time.Time
var mux = sync.Mutex{} var mux = sync.Mutex{}
func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses) return r.setupRefCounter(initAddresses)
} }

View File

@ -1,9 +1,11 @@
package net package net
import ( import (
"github.com/netbirdio/netbird/iface/netstack" "net"
"os" "os"
"github.com/netbirdio/netbird/iface/netstack"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -18,6 +20,9 @@ const (
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. // It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
type ConnectionID string type ConnectionID string
type AddHookFunc func(connID ConnectionID, IP net.IP) error
type RemoveHookFunc func(connID ConnectionID) error
// GenerateConnID generates a unique identifier for each connection. // GenerateConnID generates a unique identifier for each connection.
func GenerateConnID() ConnectionID { func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString()) return ConnectionID(uuid.NewString())