mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-15 17:52:47 +02:00
[client] Cleanup firewall state on startup (#2768)
This commit is contained in:
@ -3,7 +3,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
@ -18,6 +17,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -48,28 +48,31 @@ type routeFilteringRuleParams struct {
|
||||
SetName string
|
||||
}
|
||||
|
||||
type routeRules map[string][]string
|
||||
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
type router struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
iptablesClient *iptables.IPTables
|
||||
rules map[string][]string
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
rules routeRules
|
||||
ipsetCounter *ipsetCounter
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
r := &router{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||
return struct{}{}, r.createIpSet(name, sources)
|
||||
},
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
err := r.cleanUpDefaultForwardRules()
|
||||
if err != nil {
|
||||
log.Errorf("cleanup routing rules: %s", err)
|
||||
return nil, err
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
r.stateManager = stateManager
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
err = r.createContainers()
|
||||
if err != nil {
|
||||
log.Errorf("create containers for route: %s", err)
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
return r, err
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering(
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
r.updateState()
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -164,18 +178,18 @@ func (r *router) findSetNameInRule(rule []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return struct{}{}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -280,6 +298,9 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@ -294,6 +315,8 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@ -431,6 +454,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) updateState() {
|
||||
if r.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
lointdir := "-o"
|
||||
|
Reference in New Issue
Block a user