Add default forward rule (#1021)

* Add default forward rule

* Fix

* Add multiple forward rules

* Fix delete rule error handling
This commit is contained in:
Zoltan Papp 2023-07-22 18:39:23 +02:00 committed by GitHub
parent 2541c78dd0
commit 6c2ed4b4f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,6 +19,9 @@ const (
nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd"
nftablesRoutingNatChain = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
)
// constants needed to create nftable rules
@ -71,25 +74,28 @@ var (
)
type nftablesManager struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
chains map[string]map[string]*nftables.Chain
rules map[string]*nftables.Rule
mux sync.Mutex
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
chains map[string]map[string]*nftables.Chain
rules map[string]*nftables.Rule
filterTable *nftables.Table
defaultForwardRules []*nftables.Rule
mux sync.Mutex
}
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
mgr := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2),
}
err := mgr.isSupported()
@ -97,6 +103,11 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
return nil, err
}
err = mgr.readFilterTable()
if err != nil {
return nil, err
}
return mgr, nil
}
@ -109,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() {
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
}
if n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delete forward rule: %s", err)
}
}
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
}
@ -241,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error {
return nil
}
func (n *nftablesManager) readFilterTable() error {
tables, err := n.conn.ListTables()
if err != nil {
return err
}
for _, t := range tables {
if t.Name == "filter" {
n.filterTable = t
return nil
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil {
return nil
}
err := n.refreshDefaultForwardRule()
if err != nil {
return err
}
for i, r := range n.defaultForwardRules {
err = n.conn.DelRule(r)
if err != nil {
log.Errorf("failed to delete forward rule (%d): %s", i, err)
}
n.defaultForwardRules[i] = nil
}
return nil
}
func (n *nftablesManager) refreshDefaultForwardRule() error {
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
if err != nil {
return fmt.Errorf("unable to list rules in forward chain: %s", err)
}
found := false
for i, r := range n.defaultForwardRules {
for _, rule := range rules {
if string(rule.UserData) == string(r.UserData) {
n.defaultForwardRules[i] = rule
found = true
break
}
}
}
if !found {
return fmt.Errorf("unable to find forward accept rule")
}
return nil
}
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
src := generateCIDRMatcherExpressions("source", sourceNetwork)
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{
Kind: expr.VerdictAccept,
})...)
r := &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
n.defaultForwardRules[0] = n.conn.AddRule(r)
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{
Kind: expr.VerdictAccept,
})...)
r = &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
n.defaultForwardRules[1] = n.conn.AddRule(r)
return nil
}
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
_, foundIPv4 := n.rules[ipv4Forwarding]
@ -294,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
}
}
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
err = n.acceptForwardRule(pair.source)
if err != nil {
log.Errorf("unable to create default forward rule: %s", err)
}
log.Debugf("default accept forward rule added")
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
@ -374,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return err
}
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delte default fwd rule: %s", err)
}
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)