mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-24 19:51:33 +02:00
Merge pull request #991 from netbirdio/fix/improve_uspfilter_performance
Improve userspace filter performance
This commit is contained in:
commit
f40951cdf5
@ -21,11 +21,13 @@ type IFaceMapper interface {
|
|||||||
SetFilter(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RuleSet is a set of rules grouped by a string key
|
||||||
|
type RuleSet map[string]Rule
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules []Rule
|
outgoingRules map[string]RuleSet
|
||||||
incomingRules []Rule
|
incomingRules map[string]RuleSet
|
||||||
rulesIndex map[string]int
|
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
|
|
||||||
@ -48,7 +50,6 @@ type decoder struct {
|
|||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface IFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
outgoingRules: make(map[string]RuleSet),
|
||||||
|
incomingRules: make(map[string]RuleSet),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
@ -124,15 +127,17 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var p int
|
|
||||||
if direction == fw.RuleDirectionIN {
|
if direction == fw.RuleDirectionIN {
|
||||||
m.incomingRules = append(m.incomingRules, r)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
p = len(m.incomingRules) - 1
|
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append(m.outgoingRules, r)
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
p = len(m.outgoingRules) - 1
|
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
m.rulesIndex[r.id] = p
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &r, nil
|
return &r, nil
|
||||||
@ -148,24 +153,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, ok := m.rulesIndex[r.id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
|
||||||
}
|
|
||||||
delete(m.rulesIndex, r.id)
|
|
||||||
|
|
||||||
var toUpdate []Rule
|
|
||||||
if r.direction == fw.RuleDirectionIN {
|
if r.direction == fw.RuleDirectionIN {
|
||||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||||
toUpdate = m.incomingRules
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(m.incomingRules[r.ip.String()], r.id)
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||||
toUpdate = m.outgoingRules
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(toUpdate); i++ {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,9 +175,8 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = m.outgoingRules[:0]
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
m.incomingRules = m.incomingRules[:0]
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
m.rulesIndex = make(map[string]int)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -192,7 +192,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter imlements same logic for booth direction of the traffic
|
// dropFilter imlements same logic for booth direction of the traffic
|
||||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
@ -224,37 +224,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
|
|
||||||
// check if IP address match by IP
|
var ip net.IP
|
||||||
|
switch ipLayer {
|
||||||
|
case layers.LayerTypeIPv4:
|
||||||
|
if isIncomingPacket {
|
||||||
|
ip = d.ip4.SrcIP
|
||||||
|
} else {
|
||||||
|
ip = d.ip4.DstIP
|
||||||
|
}
|
||||||
|
case layers.LayerTypeIPv6:
|
||||||
|
if isIncomingPacket {
|
||||||
|
ip = d.ip6.SrcIP
|
||||||
|
} else {
|
||||||
|
ip = d.ip6.DstIP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||||
|
if ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// default policy is DROP ALL
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||||
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP {
|
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||||
switch ipLayer {
|
continue
|
||||||
case layers.LayerTypeIPv4:
|
|
||||||
if isIncomingPacket {
|
|
||||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case layers.LayerTypeIPv6:
|
|
||||||
if isIncomingPacket {
|
|
||||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
if rule.protoLayer == layerTypeAll {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if payloadLayer != rule.protoLayer {
|
||||||
@ -264,38 +276,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
// we ignore rule.drop and call this hook
|
// we ignore rule.drop and call this hook
|
||||||
if rule.udpHook != nil {
|
if rule.udpHook != nil {
|
||||||
return rule.udpHook(packetData)
|
return rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return rule.drop
|
return rule.drop, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false, false
|
||||||
// default policy is DROP ALL
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
@ -325,19 +335,19 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
var toUpdate []Rule
|
|
||||||
if in {
|
if in {
|
||||||
r.direction = fw.RuleDirectionIN
|
r.direction = fw.RuleDirectionIN
|
||||||
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.incomingRules
|
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.incomingRules[r.ip.String()][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||||
toUpdate = m.outgoingRules
|
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
|
}
|
||||||
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range toUpdate {
|
|
||||||
m.rulesIndex[toUpdate[i].id] = i
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@ -345,14 +355,18 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
// RemovePacketHook removes packet hook by given ID
|
// RemovePacketHook removes packet hook by given ID
|
||||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
for _, r := range m.incomingRules {
|
for _, arr := range m.incomingRules {
|
||||||
if r.id == hookID {
|
for _, r := range arr {
|
||||||
return m.DeleteRule(&r)
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, r := range m.outgoingRules {
|
for _, arr := range m.outgoingRules {
|
||||||
if r.id == hookID {
|
for _, r := range arr {
|
||||||
return m.DeleteRule(&r)
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
|
@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.DeleteRule(rule2)
|
err = m.DeleteRule(rule2)
|
||||||
@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||||
t.Errorf("rule1 still in the rulesIndex")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
incomingRules: []Rule{},
|
incomingRules: map[string]RuleSet{},
|
||||||
outgoingRules: []Rule{},
|
outgoingRules: map[string]RuleSet{},
|
||||||
rulesIndex: make(map[string]int),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule Rule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules) != 1 {
|
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.incomingRules[0]
|
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(manager.outgoingRules) != 1 {
|
if len(manager.outgoingRules) != 1 {
|
||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addedRule = manager.outgoingRules[0]
|
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||||
|
addedRule = rule
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if !tt.ip.Equal(addedRule.ip) {
|
||||||
@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected udpHook to be set")
|
t.Errorf("expected udpHook to be set")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure rulesIndex is correctly updated
|
|
||||||
index, ok := manager.rulesIndex[addedRule.id]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("expected rule to be in rulesIndex")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if index != 0 {
|
|
||||||
t.Errorf("expected rule index to be 0, got %d", index)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||||
t.Errorf("rules is not empty")
|
t.Errorf("rules is not empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
if rule.id == hookID {
|
for _, rule := range arr {
|
||||||
found = true
|
if rule.id == hookID {
|
||||||
break
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||||
for _, rule := range manager.outgoingRules {
|
for _, arr := range manager.outgoingRules {
|
||||||
if rule.id == hookID {
|
for _, rule := range arr {
|
||||||
t.Fatalf("The hook was not removed properly.")
|
if rule.id == hookID {
|
||||||
|
t.Fatalf("The hook was not removed properly.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user