mirror of
https://github.com/netbirdio/netbird.git
synced 2024-10-05 09:42:09 +02:00
Add default forward rule (#1021)
* Add default forward rule * Fix * Add multiple forward rules * Fix delete rule error handling
This commit is contained in:
parent
2541c78dd0
commit
6c2ed4b4f2
@ -19,6 +19,9 @@ const (
|
|||||||
nftablesTable = "netbird-rt"
|
nftablesTable = "netbird-rt"
|
||||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||||
nftablesRoutingNatChain = "netbird-rt-nat"
|
nftablesRoutingNatChain = "netbird-rt-nat"
|
||||||
|
|
||||||
|
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||||
|
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to create nftable rules
|
// constants needed to create nftable rules
|
||||||
@ -71,25 +74,28 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type nftablesManager struct {
|
type nftablesManager struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
tableIPv4 *nftables.Table
|
tableIPv4 *nftables.Table
|
||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
chains map[string]map[string]*nftables.Chain
|
chains map[string]map[string]*nftables.Chain
|
||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
mux sync.Mutex
|
filterTable *nftables.Table
|
||||||
|
defaultForwardRules []*nftables.Rule
|
||||||
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
mgr := &nftablesManager{
|
mgr := &nftablesManager{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
conn: &nftables.Conn{},
|
conn: &nftables.Conn{},
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
chains: make(map[string]map[string]*nftables.Chain),
|
||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
|
defaultForwardRules: make([]*nftables.Rule, 2),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := mgr.isSupported()
|
err := mgr.isSupported()
|
||||||
@ -97,6 +103,11 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = mgr.readFilterTable()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return mgr, nil
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,6 +120,13 @@ func (n *nftablesManager) CleanRoutingRules() {
|
|||||||
n.conn.FlushTable(n.tableIPv6)
|
n.conn.FlushTable(n.tableIPv6)
|
||||||
n.conn.FlushTable(n.tableIPv4)
|
n.conn.FlushTable(n.tableIPv4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n.defaultForwardRules[0] != nil {
|
||||||
|
err := n.eraseDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete forward rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,6 +259,112 @@ func (n *nftablesManager) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) readFilterTable() error {
|
||||||
|
tables, err := n.conn.ListTables()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == "filter" {
|
||||||
|
n.filterTable = t
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
||||||
|
if n.defaultForwardRules[0] == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := n.refreshDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, r := range n.defaultForwardRules {
|
||||||
|
err = n.conn.DelRule(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete forward rule (%d): %s", i, err)
|
||||||
|
}
|
||||||
|
n.defaultForwardRules[i] = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) refreshDefaultForwardRule() error {
|
||||||
|
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to list rules in forward chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for i, r := range n.defaultForwardRules {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if string(rule.UserData) == string(r.UserData) {
|
||||||
|
n.defaultForwardRules[i] = rule
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("unable to find forward accept rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
|
||||||
|
src := generateCIDRMatcherExpressions("source", sourceNetwork)
|
||||||
|
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
|
||||||
|
|
||||||
|
var exprs []expr.Any
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
r := &nftables.Rule{
|
||||||
|
Table: n.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: n.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||||
|
}
|
||||||
|
|
||||||
|
n.defaultForwardRules[0] = n.conn.AddRule(r)
|
||||||
|
|
||||||
|
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
|
||||||
|
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
|
||||||
|
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
r = &nftables.Rule{
|
||||||
|
Table: n.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: n.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||||
|
}
|
||||||
|
|
||||||
|
n.defaultForwardRules[1] = n.conn.AddRule(r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
||||||
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
||||||
_, foundIPv4 := n.rules[ipv4Forwarding]
|
_, foundIPv4 := n.rules[ipv4Forwarding]
|
||||||
@ -294,6 +418,14 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
|
||||||
|
err = n.acceptForwardRule(pair.source)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to create default forward rule: %s", err)
|
||||||
|
}
|
||||||
|
log.Debugf("default accept forward rule added")
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||||
@ -374,6 +506,13 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
|
||||||
|
err := n.eraseDefaultForwardRule()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delte default fwd rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||||
|
Loading…
Reference in New Issue
Block a user