[client] Enable userspace forwarder conditionally (#3309)

* Enable userspace forwarder conditionally

* Move disable/enable logic
This commit is contained in:
Viktor Liu 2025-02-12 11:10:49 +01:00 committed by GitHub
parent 18f84f0df5
commit b41de7fcd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 145 additions and 54 deletions

View File

@ -218,6 +218,14 @@ func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@ -101,6 +101,10 @@ type Manager interface {
Flush() error
SetLogLevel(log.Level)
EnableRouting() error
DisableRouting() error
}
func GenKey(format string, pair RouterPair) string {

View File

@ -323,6 +323,14 @@ func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets

View File

@ -74,6 +74,8 @@ type Manager struct {
mutex sync.RWMutex
// indicates whether server routes are disabled
disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves
routingEnabled bool
// indicates whether we leave forwarding and filtering to the native firewall
@ -125,15 +127,27 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil
}
func parseCreateEnv() (bool, bool) {
var disableConntrack, enableLocalForwarding bool
var err error
if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
}
}
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
}
}
return disableConntrack, enableLocalForwarding
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
}
enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding))
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
}
disableConntrack, enableLocalForwarding := parseCreateEnv()
m := &Manager{
decoders: sync.Pool{
@ -149,15 +163,16 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return d
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
// default true for non-netstack, for netstack only if explicitly enabled
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
}
@ -166,7 +181,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return nil, fmt.Errorf("update local IPs: %w", err)
}
// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
@ -175,7 +189,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
m.determineRouting(iface, disableServerRoutes)
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
@ -213,9 +232,21 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
return nil
}
func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) {
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter))
func (m *Manager) determineRouting() error {
var disableUspRouting, forceUserspaceRouter bool
var err error
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
disableUspRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
}
}
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
forceUserspaceRouter, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
}
}
switch {
case disableUspRouting:
@ -223,7 +254,7 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes
m.nativeRouter = false
log.Info("userspace routing is disabled")
case disableServerRoutes:
case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true
m.nativeRouter = true
@ -252,32 +283,37 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes
log.Info("userspace routing enabled by default")
}
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding ||
m.routingEnabled && !m.nativeRouter {
m.initForwarder(iface)
if m.routingEnabled && !m.nativeRouter {
return m.initForwarder()
}
return nil
}
// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder(iface common.IFaceMapper) {
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := iface.GetWGDevice()
if intf == nil {
log.Info("forwarding not supported")
m.routingEnabled = false
return
func (m *Manager) initForwarder() error {
if m.forwarder != nil {
return nil
}
forwarder, err := forwarder.New(iface, m.logger, m.netstack)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice()
if intf == nil {
m.routingEnabled = false
return
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
if err != nil {
m.routingEnabled = false
return fmt.Errorf("create forwarder: %w", err)
}
m.forwarder = forwarder
log.Debug("forwarder initialized")
return nil
}
func (m *Manager) Init(*statemanager.Manager) error {
@ -285,7 +321,7 @@ func (m *Manager) Init(*statemanager.Manager) error {
}
func (m *Manager) IsServerRouteSupported() bool {
return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
@ -586,7 +622,6 @@ func (m *Manager) dropFilter(packetData []byte) bool {
defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) {
m.logger.Trace("Invalid packet structure")
return true
}
@ -658,11 +693,9 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
return false
}
// Get protocol and ports for route ACL check
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
// Check route ACLs
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
srcIP, srcPort, dstIP, dstPort, proto)
@ -704,12 +737,12 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err)
m.logger.Trace("couldn't decode packet, err: %s", err)
return false
}
if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet")
m.logger.Trace("packet doesn't have network and transport layers")
return false
}
return true
@ -953,3 +986,34 @@ func (m *Manager) SetLogLevel(level log.Level) {
m.logger.SetLevel(nblog.Level(level))
}
}
func (m *Manager) EnableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.determineRouting()
}
func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.forwarder == nil {
return nil
}
m.routingEnabled = false
m.nativeRouter = false
// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nil
}
m.forwarder.Stop()
m.forwarder = nil
log.Debug("forwarder stopped")
return nil
}

View File

@ -303,6 +303,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}
manager, err := Create(ifaceMock, false)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled)

View File

@ -286,15 +286,15 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)
}
m.clientRoutes = newClientRoutesIDMap
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
if m.serverRouter == nil {
return nil
}
m.clientRoutes = newClientRoutesIDMap
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
return fmt.Errorf("update routes: %w", err)
}
return nil
}

View File

@ -71,9 +71,15 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
}
if len(m.routes) > 0 {
err := systemops.EnableIPForwarding()
if err != nil {
return err
if err := systemops.EnableIPForwarding(); err != nil {
return fmt.Errorf("enable ip forwarding: %w", err)
}
if err := m.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err)
}
} else {
if err := m.firewall.DisableRouting(); err != nil {
return fmt.Errorf("disable routing: %w", err)
}
}