[client] Manage the IP forwarding sysctl setting in global way (#3270)

Add new package ipfwdstate that implements reference counting for IP forwarding
state management. This allows multiple usage to safely request IP forwarding
without interfering with each other.
This commit is contained in:
Zoltan Papp 2025-02-03 12:27:18 +01:00 committed by GitHub
parent a85ea1ddb0
commit 1b011a2d85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 94 additions and 13 deletions

View File

@ -16,6 +16,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@ -76,6 +77,7 @@ type router struct {
legacyManagement bool legacyManagement bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
} }
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
@ -83,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@ -217,6 +220,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@ -243,6 +250,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)
} }
@ -575,6 +586,10 @@ func (r *router) updateState() {
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID() ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists { if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil return rule, nil
@ -669,6 +684,10 @@ func (r *router) rollbackRules(rules map[string]ruleInfo) error {
} }
func (r *router) DeleteDNATRule(rule firewall.Rule) error { func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID() ruleKey := rule.ID()
var merr *multierror.Error var merr *multierror.Error

View File

@ -21,6 +21,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -56,6 +57,7 @@ type router struct {
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
} }
@ -66,6 +68,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@ -464,6 +467,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@ -890,6 +897,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@ -951,6 +962,10 @@ func (r *router) refreshRulesMap() error {
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID() ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists { if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil return rule, nil
@ -1174,6 +1189,10 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
} }
func (r *router) DeleteDNATRule(rule firewall.Rule) error { func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID() ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {

View File

@ -0,0 +1,51 @@
package ipfwdstate
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// IPForwardingState is a struct that keeps track of the IP forwarding state.
// todo: read initial state of the IP forwarding from the system and reset the state based on it
type IPForwardingState struct {
enabledCounter int
}
func NewIPForwardingState() *IPForwardingState {
return &IPForwardingState{}
}
func (f *IPForwardingState) RequestForwarding() error {
if f.enabledCounter != 0 {
f.enabledCounter++
return nil
}
if err := systemops.EnableIPForwarding(); err != nil {
return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err)
}
f.enabledCounter = 1
log.Info("IP forwarding enabled")
return nil
}
func (f *IPForwardingState) ReleaseForwarding() error {
if f.enabledCounter == 0 {
return nil
}
if f.enabledCounter > 1 {
f.enabledCounter--
return nil
}
// if failed to disable IP forwarding we anyway decrement the counter
f.enabledCounter = 0
// todo call systemops.DisableIPForwarding()
return nil
}

View File

@ -13,7 +13,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -70,13 +69,6 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
m.routes[id] = newRoute m.routes[id] = newRoute
} }
if len(m.routes) > 0 {
err := systemops.EnableIPForwarding()
if err != nil {
return err
}
}
return nil return nil
} }