[client] Mark redirected traffic early to match input filters on pre-DNAT ports (#3205)

This commit is contained in:
Viktor Liu 2025-01-23 18:00:51 +01:00 committed by GitHub
parent 790a9ed7df
commit eb2ac039c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 145 additions and 87 deletions

View File

@ -3,6 +3,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net"
"slices"
"strconv" "strconv"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
@ -99,6 +100,16 @@ func (m *aclManager) AddPeerFiltering(
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal) ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName) specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" { if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
@ -130,7 +141,7 @@ func (m *aclManager) AddPeerFiltering(
m.ipsetStore.addIpList(ipsetName, ipList) m.ipsetStore.addIpList(ipsetName, ipList)
} }
ok, err := m.iptablesClient.Exists("filter", chain, specs...) ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err) return nil, fmt.Errorf("failed to check rule: %w", err)
} }
@ -138,16 +149,22 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
return nil, err return nil, err
} }
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{ rule := &Rule{
ruleID: uuid.New().String(), ruleID: uuid.New().String(),
specs: specs, specs: specs,
ipsetName: ipsetName, mangleSpecs: mangleSpecs,
ip: ip.String(), ipsetName: ipsetName,
chain: chain, ip: ip.String(),
chain: chain,
} }
m.updateState() m.updateState()
@ -190,6 +207,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
} }
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
m.updateState() m.updateState()
return nil return nil
@ -310,17 +333,10 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{ m.optionalEntries["FORWARD"] = []entry{
{ {
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules}, spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2, position: 2,
}, },
} }
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
} }
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
@ -377,7 +393,7 @@ func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.A
if dPort != "" { if dPort != "" {
specs = append(specs, "--dport", dPort) specs = append(specs, "--dport", dPort)
} }
return append(specs, "-j", actionToStr(action)) return specs
} }
func actionToStr(action firewall.Action) string { func actionToStr(action firewall.Action) string {

View File

@ -5,9 +5,10 @@ type Rule struct {
ruleID string ruleID string
ipsetName string ipsetName string
specs []string specs []string
ip string mangleSpecs []string
chain string ip string
chain string
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -46,6 +47,7 @@ type AclManager struct {
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
@ -118,23 +120,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
} }
if r.nftSet == nil { if r.nftSet == nil {
err := m.rConn.DelRule(r.nftRule) if err := m.rConn.DelRule(r.nftRule); err != nil {
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID()) delete(m.rules, r.GetRuleID())
return m.rConn.Flush() return m.rConn.Flush()
} }
ips, ok := m.ipsetStore.ips(r.nftSet.Name) ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok { if !ok {
err := m.rConn.DelRule(r.nftRule) if err := m.rConn.DelRule(r.nftRule); err != nil {
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID()) delete(m.rules, r.GetRuleID())
return m.rConn.Flush() return m.rConn.Flush()
} }
if _, ok := ips[r.ip.String()]; ok { if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
if err != nil { if err != nil {
@ -153,12 +164,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
err := m.rConn.DelRule(r.nftRule) if err := m.rConn.DelRule(r.nftRule); err != nil {
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
err = m.rConn.Flush() if r.mangleRule != nil {
if err != nil { if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err return err
} }
@ -225,9 +240,12 @@ func (m *AclManager) Flush() error {
return err return err
} }
if err := m.refreshRuleHandles(m.chainInputRules); err != nil { if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil return nil
} }
@ -244,10 +262,11 @@ func (m *AclManager) addIOFiltering(
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, nftRule: r.nftRule,
r.nftSet, mangleRule: r.mangleRule,
r.ruleID, nftSet: r.nftSet,
ip, ruleID: r.ruleID,
ip: ip,
}, nil }, nil
} }
@ -340,11 +359,13 @@ func (m *AclManager) addIOFiltering(
) )
} }
mainExpressions := slices.Clone(expressions)
switch action { switch action {
case firewall.ActionAccept: case firewall.ActionAccept:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop: case firewall.ActionDrop:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
userData := []byte(strings.Join([]string{ruleId, comment}, " ")) userData := []byte(strings.Join([]string{ruleId, comment}, " "))
@ -353,15 +374,16 @@ func (m *AclManager) addIOFiltering(
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: mainExpressions,
UserData: userData, UserData: userData,
}) })
rule := &Rule{ rule := &Rule{
nftRule: nftRule, nftRule: nftRule,
nftSet: ipset, mangleRule: m.createPreroutingRule(expressions, userData),
ruleID: ruleId, nftSet: ipset,
ip: ip, ruleID: ruleId,
ip: ip,
} }
m.rules[ruleId] = rule m.rules[ruleId] = rule
if ipset != nil { if ipset != nil {
@ -370,6 +392,59 @@ func (m *AclManager) addIOFiltering(
return rule, nil return rule, nil
} }
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
return m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
}
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules // chainNameInputRules
chain := m.createChain(chainNameInputRules) chain := m.createChain(chainNameInputRules)
@ -413,7 +488,7 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP. // netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{ m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting, Name: chainNamePrerouting,
Table: m.workTable, Table: m.workTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
@ -421,8 +496,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
}) })
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter) m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
@ -432,43 +505,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
return nil return nil
} }
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{ m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
@ -484,8 +520,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictAccept,
Chain: m.chainInputRules.Name,
}, },
}, },
}) })
@ -632,6 +667,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
for i := 0; ; i++ { for i := 0; ; i++ {
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") { if !strings.Contains(err.Error(), "busy") {
return return
} }
@ -648,7 +684,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
return return
} }
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil { if m.workTable == nil || chain == nil {
return nil return nil
} }
@ -665,7 +701,11 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
split := bytes.Split(rule.UserData, []byte(" ")) split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])] r, ok := m.rules[string(split[0])]
if ok { if ok {
*r.nftRule = *rule if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
} }
} }

View File

@ -8,10 +8,11 @@ import (
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
nftRule *nftables.Rule nftRule *nftables.Rule
nftSet *nftables.Set mangleRule *nftables.Rule
ruleID string nftSet *nftables.Set
ip net.IP ruleID string
ip net.IP
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id