diff --git a/client/internal/engine.go b/client/internal/engine.go index fce1e162b..68c287046 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -42,6 +42,7 @@ import ( signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/util" + nbnet "github.com/netbirdio/netbird/util/net" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -105,8 +106,8 @@ type Engine struct { // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn - beforePeerHook peer.BeforeAddPeerHookFunc - afterPeerHook peer.AfterRemovePeerHookFunc + beforePeerHook nbnet.AddHookFunc + afterPeerHook nbnet.RemoveHookFunc // rpManager is a Rosenpass manager rpManager *rosenpass.Manager diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index c64c074a7..3a38d14c1 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/iface" @@ -99,9 +100,6 @@ type IceCredentials struct { Pwd string } -type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error -type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error - type Conn struct { config ConnConfig mu sync.Mutex @@ -136,8 +134,8 @@ type Conn struct { sentExtraSrflx bool connID nbnet.ConnectionID - beforeAddPeerHooks []BeforeAddPeerHookFunc - afterRemovePeerHooks []AfterRemovePeerHookFunc + beforeAddPeerHooks []nbnet.AddHookFunc + afterRemovePeerHooks []nbnet.RemoveHookFunc } // GetConf returns the connection config @@ -380,11 +378,11 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } -func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { +func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) { conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) } -func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { +func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) { conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) } @@ -801,10 +799,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive } func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { - var routePrefixes []netip.Prefix + var vpnRoutes []netip.Prefix for _, routes := range clientRoutes { if len(routes) > 0 && routes[0] != nil { - routePrefixes = append(routePrefixes, routes[0].Network) + vpnRoutes = append(vpnRoutes, routes[0].Network) } } @@ -814,16 +812,10 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool return false } - for _, prefix := range routePrefixes { - // default route is - if prefix.Bits() == 0 { - continue - } - - if prefix.Contains(addr) { - log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix) - return true - } + if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn { + log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix) + return true } + return false } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 53943055c..150e76682 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -28,7 +28,7 @@ import ( // Manager is a route manager interface type Manager interface { - Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) + Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector @@ -113,7 +113,7 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index adbef8061..58a66715c 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,10 +6,10 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/util/net" ) // MockManager is the mock instance of a route manager @@ -20,7 +20,7 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 01b9ebda6..53bab6edf 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -16,7 +16,6 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" @@ -29,7 +28,9 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) -func (r *SysOps) setupRefCounter(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +var ErrRoutingIsSeparate = errors.New("routing is separate") + +func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) @@ -273,7 +274,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { @@ -414,3 +415,58 @@ func isSubRange(prefix netip.Prefix) (bool, error) { } return false, nil } + +// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. +func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { + localRoutes, err := hasSeparateRouting() + if err != nil { + if !errors.Is(err, ErrRoutingIsSeparate) { + log.Errorf("Failed to get routes: %v", err) + } + return false, netip.Prefix{} + } + + return isVpnRoute(addr, vpnRoutes, localRoutes) +} + +func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) { + vpnPrefixMap := map[netip.Prefix]struct{}{} + for _, prefix := range vpnRoutes { + vpnPrefixMap[prefix] = struct{}{} + } + + // remove vpnRoute duplicates + for _, prefix := range localRoutes { + delete(vpnPrefixMap, prefix) + } + + var longestPrefix netip.Prefix + var isVpn bool + + combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes)) + copy(combinedRoutes, vpnRoutes) + copy(combinedRoutes[len(vpnRoutes):], localRoutes) + + for _, prefix := range combinedRoutes { + // Ignore the default route, it has special handling + if prefix.Bits() == 0 { + continue + } + + if prefix.Contains(addr) { + // Longest prefix match + if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() { + longestPrefix = prefix + _, isVpn = vpnPrefixMap[prefix] + } + } + } + + if !longestPrefix.IsValid() { + // No route matched + return false, netip.Prefix{} + } + + // Return true if the longest matching prefix is from vpnRoutes + return isVpn, longestPrefix +} diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index c79b2ac64..594aaee4a 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -420,3 +420,125 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP") } } + +func TestIsVpnRoute(t *testing.T) { + tests := []struct { + name string + addr string + vpnRoutes []string + localRoutes []string + expectedVpn bool + expectedPrefix netip.Prefix + }{ + { + name: "Match in VPN routes", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Match in local routes", + addr: "10.1.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"), + }, + { + name: "No match", + addr: "172.16.0.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: false, + expectedPrefix: netip.Prefix{}, + }, + { + name: "Default route ignored", + addr: "192.168.1.1", + vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Default route matches but ignored", + addr: "172.16.1.1", + vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"}, + localRoutes: []string{"10.0.0.0/8"}, + expectedVpn: false, + expectedPrefix: netip.Prefix{}, + }, + { + name: "Longest prefix match local", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.0.0/16"}, + localRoutes: []string{"192.168.1.0/24"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Longest prefix match local multiple", + addr: "192.168.0.1", + vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"}, + localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"), + }, + { + name: "Longest prefix match vpn", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"192.168.0.0/16"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + { + name: "Longest prefix match vpn multiple", + addr: "192.168.0.1", + vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"}, + localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"}, + expectedVpn: true, + expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"), + }, + { + name: "Duplicate prefix in both", + addr: "192.168.1.1", + vpnRoutes: []string{"192.168.1.0/24"}, + localRoutes: []string{"192.168.1.0/24"}, + expectedVpn: false, + expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, err := netip.ParseAddr(tt.addr) + if err != nil { + t.Fatalf("Failed to parse address %s: %v", tt.addr, err) + } + + var vpnRoutes, localRoutes []netip.Prefix + for _, route := range tt.vpnRoutes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + t.Fatalf("Failed to parse VPN route %s: %v", route, err) + } + vpnRoutes = append(vpnRoutes, prefix) + } + + for _, route := range tt.localRoutes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + t.Fatalf("Failed to parse local route %s: %v", route, err) + } + localRoutes = append(localRoutes, prefix) + } + + isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes) + assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value") + assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix") + }) + } +} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index bca47d1f9..c4f69fba5 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -16,7 +16,6 @@ import ( "github.com/vishvananda/netlink" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" nbnet "github.com/netbirdio/netbird/util/net" @@ -86,7 +85,7 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { if isLegacy() { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses) @@ -502,3 +501,10 @@ func getAddressFamily(prefix netip.Prefix) int { } return netlink.FAMILY_V6 } + +func hasSeparateRouting() ([]netip.Prefix, error) { + if isLegacy() { + return getRoutesFromTable() + } + return nil, ErrRoutingIsSeparate +} diff --git a/client/internal/routemanager/systemops/systemops_mobile.go b/client/internal/routemanager/systemops/systemops_mobile.go index 1517cf949..43815c657 100644 --- a/client/internal/routemanager/systemops/systemops_mobile.go +++ b/client/internal/routemanager/systemops/systemops_mobile.go @@ -9,10 +9,10 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/peer" + nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return nil, nil, nil } @@ -32,3 +32,7 @@ func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } + +func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { + return false, netip.Prefix{} +} diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 0f70cd78d..0adeb0992 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -22,3 +22,7 @@ func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } + +func hasSeparateRouting() ([]netip.Prefix, error) { + return getRoutesFromTable() +} diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d4d2a31ed..a2bbf35cf 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -13,10 +13,10 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/peer" + nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return r.setupRefCounter(initAddresses) } diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index b869bc85b..a0b5fb1e5 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -17,7 +17,7 @@ import ( "github.com/yusufpapurcu/wmi" "github.com/netbirdio/netbird/client/firewall/uspfilter" - "github.com/netbirdio/netbird/client/internal/peer" + nbnet "github.com/netbirdio/netbird/util/net" ) type MSFT_NetRoute struct { @@ -56,7 +56,7 @@ var prefixList []netip.Prefix var lastUpdate time.Time var mux = sync.Mutex{} -func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return r.setupRefCounter(initAddresses) } diff --git a/util/net/net.go b/util/net/net.go index 30e058032..8d1fcebd0 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,9 +1,11 @@ package net import ( - "github.com/netbirdio/netbird/iface/netstack" + "net" "os" + "github.com/netbirdio/netbird/iface/netstack" + "github.com/google/uuid" ) @@ -18,6 +20,9 @@ const ( // It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. type ConnectionID string +type AddHookFunc func(connID ConnectionID, IP net.IP) error +type RemoveHookFunc func(connID ConnectionID) error + // GenerateConnID generates a unique identifier for each connection. func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString())