[client] Check for fwmark support and use fallback routing if not supported (#3220)

This commit is contained in:
Viktor Liu 2025-02-11 13:09:17 +01:00 committed by GitHub
parent 44407a158a
commit 18f84f0df5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 163 additions and 56 deletions

View File

@ -362,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
} }
func getFwmark() int { func getFwmark() int {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() { if nbnet.AdvancedRouting() {
return nbnet.NetbirdFwmark return nbnet.NetbirdFwmark
} }
return 0 return 0

View File

@ -31,6 +31,7 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@ -109,6 +110,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
nbnet.Init()
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,

View File

@ -113,13 +113,14 @@ func NewManager(config ManagerConfig) *DefaultManager {
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
} }
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
dm.setupRefCounters(useNoop)
// don't proceed with client routes if it is disabled // don't proceed with client routes if it is disabled
if config.DisableClientRoutes { if config.DisableClientRoutes {
return dm return dm
} }
dm.setupRefCounters()
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.initialClientRoutes(config.InitialRoutes) cr := dm.initialClientRoutes(config.InitialRoutes)
dm.notifier.SetInitialClientRoutes(cr) dm.notifier.SetInitialClientRoutes(cr)
@ -127,7 +128,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
return dm return dm
} }
func (m *DefaultManager) setupRefCounters() { func (m *DefaultManager) setupRefCounters(useNoop bool) {
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) { func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
@ -137,7 +138,7 @@ func (m *DefaultManager) setupRefCounters() {
}, },
) )
if netstack.IsEnabled() { if useNoop {
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
func(netip.Prefix, struct{}) (struct{}, error) { func(netip.Prefix, struct{}) (struct{}, error) {
return struct{}{}, refcounter.ErrIgnore return struct{}{}, refcounter.ErrIgnore
@ -449,7 +450,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
} }
func isRouteSupported(route *route.Route) bool { func isRouteSupported(route *route.Route) bool {
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() { if netstack.IsEnabled() || !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
return true return true
} }

View File

@ -53,20 +53,6 @@ type ruleParams struct {
description string description string
} }
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
}
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams { func getSetupRules() []ruleParams {
return []ruleParams{ return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
@ -87,7 +73,7 @@ func getSetupRules() []ruleParams {
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
if isLegacy() { if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
@ -103,11 +89,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true)
return r.setupRefCounter(initAddresses, stateManager)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
} }
} }
@ -130,7 +111,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
// It systematically removes the three rules and any associated routing table entries to ensure a clean state. // It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process. // The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
if isLegacy() { if !nbnet.AdvancedRouting() {
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
@ -168,7 +149,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
} }
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if !nbnet.AdvancedRouting() {
return r.genericAddVPNRoute(prefix, intf) return r.genericAddVPNRoute(prefix, intf)
} }
@ -191,7 +172,7 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
} }
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if !nbnet.AdvancedRouting() {
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }
@ -504,7 +485,7 @@ func getAddressFamily(prefix netip.Prefix) int {
} }
func hasSeparateRouting() ([]netip.Prefix, error) { func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() { if !nbnet.AdvancedRouting() {
return GetRoutesFromTable() return GetRoutesFromTable()
} }
return nil, ErrRoutingIsSeparate return nil, ErrRoutingIsSeparate

View File

@ -85,6 +85,7 @@ var testCases = []testCase{
} }
func TestRouting(t *testing.T) { func TestRouting(t *testing.T) {
nbnet.Init()
for _, tc := range testCases { for _, tc := range testCases {
// todo resolve test execution on freebsd // todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" { if runtime.GOOS == "freebsd" {

View File

@ -40,7 +40,6 @@ func WithCustomDialer() grpc.DialOption {
} }
} }
log.Debug("Using nbnet.NewDialer()")
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
log.Errorf("Failed to dial: %s", err) log.Errorf("Failed to dial: %s", err)

View File

@ -2,6 +2,7 @@ package net
import ( import (
"os" "os"
"strconv"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -10,20 +11,24 @@ import (
const ( const (
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
) )
// CustomRoutingDisabled returns true if custom routing is disabled.
// This will fall back to the operation mode before the exit node functionality was implemented.
// In particular exclusion routes won't be set up and all dialers and listeners will use net.Dial and net.Listen, respectively.
func CustomRoutingDisabled() bool { func CustomRoutingDisabled() bool {
if netstack.IsEnabled() { if netstack.IsEnabled() {
return true return true
} }
return os.Getenv(envDisableCustomRouting) == "true"
}
func SkipSocketMark() bool { var customRoutingDisabled bool
if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" { if val := os.Getenv(envDisableCustomRouting); val != "" {
log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark) var err error
return true customRoutingDisabled, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableCustomRouting, err)
}
} }
return false
return customRoutingDisabled
} }

12
util/net/env_generic.go Normal file
View File

@ -0,0 +1,12 @@
//go:build !linux || android
package net
func Init() {
// nothing to do on non-linux
}
func AdvancedRouting() bool {
// non-linux currently doesn't support advanced routing
return false
}

119
util/net/env_linux.go Normal file
View File

@ -0,0 +1,119 @@
//go:build linux && !android
package net
import (
"errors"
"os"
"strconv"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
// these have the same effect, skip socket env supported for backward compatibility
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
)
var advancedRoutingSupported bool
func Init() {
advancedRoutingSupported = checkAdvancedRoutingSupport()
}
func AdvancedRouting() bool {
return advancedRoutingSupported
}
func checkAdvancedRoutingSupport() bool {
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
var skipSocketMark bool
if val := os.Getenv(envSkipSocketMark); val != "" {
skipSocketMark, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipSocketMark, err)
}
}
// requested to disable advanced routing
if legacyRouting || skipSocketMark ||
// envCustomRoutingDisabled disables the custom dialers.
// There is no point in using advanced routing without those, as they set up fwmarks on the sockets.
CustomRoutingDisabled() ||
// netstack mode doesn't need routing at all
netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}
if !CheckFwmarkSupport() || !CheckRuleOperationsSupport() {
log.Warn("system doesn't support required routing features, falling back to legacy routing")
return false
}
log.Info("system supports advanced routing")
return true
}
func CheckFwmarkSupport() bool {
// temporarily enable advanced routing to check fwmarks are supported
old := advancedRoutingSupported
advancedRoutingSupported = true
defer func() {
advancedRoutingSupported = old
}()
dialer := NewDialer()
dialer.Timeout = 100 * time.Millisecond
conn, err := dialer.Dial("udp", "127.0.0.1:9")
if err != nil {
log.Warnf("failed to dial with fwmark: %v", err)
return false
}
if err := conn.Close(); err != nil {
log.Warnf("failed to close connection: %v", err)
}
return true
}
func CheckRuleOperationsSupport() bool {
rule := netlink.NewRule()
// low precedence, semi-random
rule.Priority = 32321
rule.Table = syscall.RT_TABLE_MAIN
rule.Family = netlink.FAMILY_V4
if err := netlink.RuleAdd(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warn("IP rule operations are not supported")
return false
}
log.Warnf("failed to test rule support: %v", err)
return false
}
if err := netlink.RuleDel(rule); err != nil {
log.Warnf("failed to delete test rule: %v", err)
}
return true
}

View File

@ -5,13 +5,11 @@ package net
import ( import (
"fmt" "fmt"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
) )
// SetSocketMark sets the SO_MARK option on the given socket connection // SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error { func SetSocketMark(conn syscall.Conn) error {
if isSocketMarkDisabled() { if !AdvancedRouting() {
return nil return nil
} }
@ -25,7 +23,7 @@ func SetSocketMark(conn syscall.Conn) error {
// SetSocketOpt sets the SO_MARK option on the given file descriptor // SetSocketOpt sets the SO_MARK option on the given file descriptor
func SetSocketOpt(fd int) error { func SetSocketOpt(fd int) error {
if isSocketMarkDisabled() { if !AdvancedRouting() {
return nil return nil
} }
@ -36,7 +34,7 @@ func setRawSocketMark(conn syscall.RawConn) error {
var setErr error var setErr error
err := conn.Control(func(fd uintptr) { err := conn.Control(func(fd uintptr) {
if isSocketMarkDisabled() { if !AdvancedRouting() {
return return
} }
setErr = setSocketOptInt(int(fd)) setErr = setSocketOptInt(int(fd))
@ -55,15 +53,3 @@ func setRawSocketMark(conn syscall.RawConn) error {
func setSocketOptInt(fd int) error { func setSocketOptInt(fd int) error {
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
} }
func isSocketMarkDisabled() bool {
if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return true
}
if SkipSocketMark() {
return true
}
return false
}