[client] Remove outbound chains (#3157)

This commit is contained in:
Viktor Liu 2025-01-15 16:57:41 +01:00 committed by GitHub
parent 1ffa519387
commit 5a82477d48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 92 additions and 345 deletions

View File

@ -19,8 +19,7 @@ const (
tableName = "filter" tableName = "filter"
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
) )
type aclEntries map[string][][]string type aclEntries map[string][][]string
@ -84,7 +83,6 @@ func (m *aclManager) AddPeerFiltering(
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
@ -97,15 +95,10 @@ func (m *aclManager) AddPeerFiltering(
sPortVal = strconv.Itoa(sPort.Values[0]) sPortVal = strconv.Itoa(sPort.Values[0])
} }
var chain string chain := chainNameInputRules
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal) ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, action, ipsetName)
if ipsetName != "" { if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
@ -214,28 +207,7 @@ func (m *aclManager) Reset() error {
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error { func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil { if err != nil {
log.Debugf("failed to list chains: %s", err) log.Debugf("failed to list chains: %s", err)
return err return err
@ -295,12 +267,6 @@ func (m *aclManager) createDefaultChains() error {
return err return err
} }
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
@ -329,8 +295,6 @@ func (m *aclManager) createDefaultChains() error {
// The existing FORWARD rules/policies decide outbound traffic towards our interface. // The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished() established := getConntrackEstablished()
@ -390,30 +354,18 @@ func (m *aclManager) updateState() {
} }
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" { if ip.String() == "0.0.0.0" {
matchByIP = false matchByIP = false
} }
switch direction {
case firewall.RuleDirectionIN: if matchByIP {
if matchByIP { if ipsetName != "" {
if ipsetName != "" { specs = append(specs, "-m", "set", "--set", ipsetName, "src")
specs = append(specs, "-m", "set", "--set", ipsetName, "src") } else {
} else { specs = append(specs, "-s", ip.String())
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
} }
} }
if protocol != "all" { if protocol != "all" {

View File

@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, _ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@ -201,7 +200,6 @@ func (m *Manager) AllowNetbird() error {
"all", "all",
nil, nil,
nil, nil,
firewall.RuleDirectionIN,
firewall.ActionAccept, firewall.ActionAccept,
"", "",
"", "",

View File

@ -68,27 +68,13 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@ -97,15 +83,6 @@ func TestIptablesManager(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@ -119,7 +96,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Reset(nil)
@ -135,9 +112,6 @@ func TestIptablesManager(t *testing.T) {
} }
func TestIptablesManagerIPSet(t *testing.T) { func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
@ -167,33 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []int{443}, Values: []int{443},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@ -201,15 +155,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@ -270,11 +215,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@ -69,7 +69,6 @@ type Manager interface {
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
dPort *Port, dPort *Port,
direction RuleDirection,
action Action, action Action,
ipsetName string, ipsetName string,
comment string, comment string,

View File

@ -22,8 +22,7 @@ import (
const ( const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules" chainNameInputRules = "netbird-acl-input-rules"
chainNameOutputRules = "netbird-acl-output-rules"
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
@ -45,9 +44,8 @@ type AclManager struct {
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
@ -89,7 +87,6 @@ func (m *AclManager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
@ -104,7 +101,7 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules := make([]firewall.Rule, 0, 2) newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,38 +211,6 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expIn, Exprs: expIn,
}) })
expOut := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainOutputRules,
Position: 0,
Exprs: expOut,
})
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@ -264,15 +229,19 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
return nil return nil
} }
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { func (m *AclManager) addIOFiltering(
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, r.nftRule,
@ -310,9 +279,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
if !bytes.HasPrefix(anyIP, rawIP) { if !bytes.HasPrefix(anyIP, rawIP) {
// source address position // source address position
addrOffset := uint32(12) addrOffset := uint32(12)
if direction == firewall.RuleDirectionOUT {
addrOffset += 4 // is ipv4 address length
}
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
@ -383,12 +349,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
userData := []byte(strings.Join([]string{ruleId, comment}, " ")) userData := []byte(strings.Join([]string{ruleId, comment}, " "))
var chain *nftables.Chain chain := m.chainInputRules
if direction == firewall.RuleDirectionIN {
chain = m.chainInputRules
} else {
chain = m.chainOutputRules
}
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
@ -419,15 +380,6 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
m.chainInputRules = chain m.chainInputRules = chain
// chainNameOutputRules
chain = m.createChain(chainNameOutputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
return err
}
m.chainOutputRules = chain
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
@ -720,15 +672,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil return nil
} }
func generatePeerRuleId( func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
ip net.IP, rulesetID := ":"
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipset *nftables.Set,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil { if sPort != nil {
rulesetID += sPort.String() rulesetID += sPort.String()
} }

View File

@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering( rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@ -210,11 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
}) })
ip := net.ParseIP("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering( _, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(

View File

@ -4,8 +4,6 @@ import (
"net" "net"
"github.com/google/gopacket" "github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
) )
// Rule to handle management of rules // Rule to handle management of rules
@ -15,7 +13,6 @@ type Rule struct {
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
direction firewall.RuleDirection
sPort uint16 sPort uint16
dPort uint16 dPort uint16
drop bool drop bool

View File

@ -39,7 +39,9 @@ type RuleSet map[string]Rule
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules map[string]RuleSet // outgoingRules is used for hooks only
outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@ -156,9 +158,8 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, _ string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
r := Rule{ r := Rule{
@ -166,7 +167,6 @@ func (m *Manager) AddPeerFiltering(
ip: ip, ip: ip,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
direction: direction,
drop: action == firewall.ActionDrop, drop: action == firewall.ActionDrop,
comment: comment, comment: comment,
} }
@ -202,17 +202,10 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if direction == firewall.RuleDirectionIN { if _, ok := m.incomingRules[r.ip.String()]; !ok {
if _, ok := m.incomingRules[r.ip.String()]; !ok { m.incomingRules[r.ip.String()] = make(RuleSet)
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
@ -241,19 +234,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if r.direction == firewall.RuleDirectionIN { if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
_, ok := m.incomingRules[r.ip.String()][r.id] return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else {
_, ok := m.outgoingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
} }
delete(m.incomingRules[r.ip.String()], r.id)
return nil return nil
} }
@ -566,7 +550,6 @@ func (m *Manager) AddUDPPacketHook(
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
dPort: dPort, dPort: dPort,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
direction: firewall.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook, udpHook: hook,
} }
@ -577,7 +560,6 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock() m.mutex.Lock()
if in { if in {
r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule) m.incomingRules[r.ip.String()] = make(map[string]Rule)
} }
@ -596,19 +578,22 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID // RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error { func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, arr := range m.incomingRules { for _, arr := range m.incomingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r delete(arr, r.id)
return m.DeletePeerRule(&rule) return nil
} }
} }
} }
for _, arr := range m.outgoingRules { for _, arr := range m.outgoingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r delete(arr, r.id)
return m.DeletePeerRule(&rule) return nil
} }
} }
} }

View File

@ -91,7 +91,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Single rule allowing all traffic // Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") fw.ActionAccept, "", "allow all")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Baseline: Single 'allow all' rule without connection tracking", desc: "Baseline: Single 'allow all' rule without connection tracking",
@ -114,7 +114,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}}, &fw.Port{Values: []int{1024 + i}},
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") fw.ActionAccept, "", "explicit return")
require.NoError(b, err) require.NoError(b, err)
} }
}, },
@ -126,7 +126,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections // Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") fw.ActionDrop, "", "default drop")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Connection tracking with established connections", desc: "Connection tracking with established connections",
@ -590,7 +590,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -681,7 +681,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -799,7 +799,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@ -886,7 +886,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}}, &fw.Port{Values: []int{80}},
nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }

View File

@ -70,11 +70,10 @@ func TestManagerAddPeerFiltering(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -105,37 +104,15 @@ func TestManagerDeleteRule(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule 2"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
ip = net.ParseIP("192.168.1.1")
proto = fw.ProtocolTCP
port = &fw.Port{Values: []int{80}}
direction = fw.RuleDirectionIN
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
@ -225,10 +202,6 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return return
} }
if tt.expDir != addedRule.direction {
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
return
}
if addedRule.udpHook == nil { if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set") t.Errorf("expected udpHook to be set")
return return
@ -251,11 +224,10 @@ func TestManagerReset(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -289,11 +261,10 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
direction := fw.RuleDirectionOUT
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -327,7 +298,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes(), m.outgoingRules) { if m.dropFilter(buf.Bytes(), m.incomingRules) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
@ -493,11 +464,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.rollBack(newRulePairs) d.rollBack(newRulePairs)
break break
} }
if len(rules) > 0 { if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair newRulePairs[pairID] = rulePair
} }
@ -288,6 +288,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@ -308,25 +310,12 @@ func (d *DefaultManager) addInRules(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("add firewall rule: %w", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
} }
rule, err = d.firewall.AddPeerFiltering( return rule, nil
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
} }
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
@ -337,25 +326,16 @@ func (d *DefaultManager) addOutRules(
ipsetName string, ipsetName string,
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return rules, nil return nil, nil
} }
rule, err = d.firewall.AddPeerFiltering( rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
return append(rules, rule...), nil return rule, nil
} }
// getPeerRuleID() returns unique ID for the rule based on its parameters. // getPeerRuleID() returns unique ID for the rule based on its parameters.

View File

@ -119,8 +119,8 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 { if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return return
} }
}) })
@ -356,8 +356,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 4 { if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return return
} }
} }

View File

@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
return nil return nil
} }
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err

View File

@ -495,7 +495,6 @@ func (e *Engine) initFirewall() error {
manager.ProtocolUDP, manager.ProtocolUDP,
nil, nil,
&port, &port,
manager.RuleDirectionIN,
manager.ActionAccept, manager.ActionAccept,
"", "",
"", "",