Merge pull request #991 from netbirdio/fix/improve_uspfilter_performance

Improve userspace filter performance
This commit is contained in:
pascal-fischer 2023-07-12 18:02:29 +02:00 committed by GitHub
commit f40951cdf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 119 additions and 109 deletions

View File

@ -21,11 +21,13 @@ type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(iface.PacketFilter) error
} }
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules []Rule outgoingRules map[string]RuleSet
incomingRules []Rule incomingRules map[string]RuleSet
rulesIndex map[string]int
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@ -48,7 +50,6 @@ type decoder struct {
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) { func Create(iface IFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rulesIndex: make(map[string]int),
decoders: sync.Pool{ decoders: sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
return d return d
}, },
}, },
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
@ -124,15 +127,17 @@ func (m *Manager) AddFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
var p int
if direction == fw.RuleDirectionIN { if direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules, r) if _, ok := m.incomingRules[r.ip.String()]; !ok {
p = len(m.incomingRules) - 1 m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
m.outgoingRules = append(m.outgoingRules, r) if _, ok := m.outgoingRules[r.ip.String()]; !ok {
p = len(m.outgoingRules) - 1 m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
m.rulesIndex[r.id] = p
m.mutex.Unlock() m.mutex.Unlock()
return &r, nil return &r, nil
@ -148,24 +153,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
p, ok := m.rulesIndex[r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.rulesIndex, r.id)
var toUpdate []Rule
if r.direction == fw.RuleDirectionIN { if r.direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...) _, ok := m.incomingRules[r.ip.String()][r.id]
toUpdate = m.incomingRules if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else { } else {
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...) _, ok := m.outgoingRules[r.ip.String()][r.id]
toUpdate = m.outgoingRules if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
} }
for i := 0; i < len(toUpdate); i++ {
m.rulesIndex[toUpdate[i].id] = i
}
return nil return nil
} }
@ -174,9 +175,8 @@ func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = m.outgoingRules[:0] m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = m.incomingRules[:0] m.incomingRules = make(map[string]RuleSet)
m.rulesIndex = make(map[string]int)
return nil return nil
} }
@ -192,7 +192,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
} }
// dropFilter imlements same logic for booth direction of the traffic // dropFilter imlements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool { func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
@ -224,37 +224,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
log.Errorf("unknown layer: %v", d.decoded[0]) log.Errorf("unknown layer: %v", d.decoded[0])
return true return true
} }
payloadLayer := d.decoded[1]
// check if IP address match by IP var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
}
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter
}
// default policy is DROP ALL
return true
}
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP { if rule.matchByIP && !ip.Equal(rule.ip) {
switch ipLayer { continue
case layers.LayerTypeIPv4:
if isIncomingPacket {
if !d.ip4.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip4.DstIP.Equal(rule.ip) {
continue
}
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
if !d.ip6.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip6.DstIP.Equal(rule.ip) {
continue
}
}
}
} }
if rule.protoLayer == layerTypeAll { if rule.protoLayer == layerTypeAll {
return rule.drop return rule.drop, true
} }
if payloadLayer != rule.protoLayer { if payloadLayer != rule.protoLayer {
@ -264,38 +276,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
switch payloadLayer { switch payloadLayer {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
if rule.sPort == 0 && rule.dPort == 0 { if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop return rule.drop, true
} }
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) { if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
return rule.drop return rule.drop, true
} }
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
return rule.drop return rule.drop, true
} }
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule) // if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook // we ignore rule.drop and call this hook
if rule.udpHook != nil { if rule.udpHook != nil {
return rule.udpHook(packetData) return rule.udpHook(packetData), true
} }
if rule.sPort == 0 && rule.dPort == 0 { if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop return rule.drop, true
} }
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) { if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
return rule.drop return rule.drop, true
} }
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop return rule.drop, true
} }
return rule.drop return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop return rule.drop, true
} }
} }
return false, false
// default policy is DROP ALL
return true
} }
// SetNetwork of the wireguard interface to which filtering applied // SetNetwork of the wireguard interface to which filtering applied
@ -325,19 +335,19 @@ func (m *Manager) AddUDPPacketHook(
} }
m.mutex.Lock() m.mutex.Lock()
var toUpdate []Rule
if in { if in {
r.direction = fw.RuleDirectionIN r.direction = fw.RuleDirectionIN
m.incomingRules = append([]Rule{r}, m.incomingRules...) if _, ok := m.incomingRules[r.ip.String()]; !ok {
toUpdate = m.incomingRules m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
m.incomingRules[r.ip.String()][r.id] = r
} else { } else {
m.outgoingRules = append([]Rule{r}, m.outgoingRules...) if _, ok := m.outgoingRules[r.ip.String()]; !ok {
toUpdate = m.outgoingRules m.outgoingRules[r.ip.String()] = make(map[string]Rule)
}
m.outgoingRules[r.ip.String()][r.id] = r
} }
for i := range toUpdate {
m.rulesIndex[toUpdate[i].id] = i
}
m.mutex.Unlock() m.mutex.Unlock()
return r.id return r.id
@ -345,14 +355,18 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID // RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error { func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range m.incomingRules { for _, arr := range m.incomingRules {
if r.id == hookID { for _, r := range arr {
return m.DeleteRule(&r) if r.id == hookID {
return m.DeleteRule(&r)
}
} }
} }
for _, r := range m.outgoingRules { for _, arr := range m.outgoingRules {
if r.id == hookID { for _, r := range arr {
return m.DeleteRule(&r) if r.id == hookID {
return m.DeleteRule(&r)
}
} }
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")

View File

@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 { if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the rulesIndex") t.Errorf("rule2 is not in the incomingRules")
} }
err = m.DeleteRule(rule2) err = m.DeleteRule(rule2)
@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 { if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
t.Errorf("rule1 still in the rulesIndex") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager := &Manager{ manager := &Manager{
incomingRules: []Rule{}, incomingRules: map[string]RuleSet{},
outgoingRules: []Rule{}, outgoingRules: map[string]RuleSet{},
rulesIndex: make(map[string]int),
} }
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule var addedRule Rule
if tt.in { if tt.in {
if len(manager.incomingRules) != 1 { if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
addedRule = manager.incomingRules[0] for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
} else { } else {
if len(manager.outgoingRules) != 1 { if len(manager.outgoingRules) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
addedRule = manager.outgoingRules[0] for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
} }
if !tt.ip.Equal(addedRule.ip) { if !tt.ip.Equal(addedRule.ip) {
@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected udpHook to be set") t.Errorf("expected udpHook to be set")
return return
} }
// Ensure rulesIndex is correctly updated
index, ok := manager.rulesIndex[addedRule.id]
if !ok {
t.Errorf("expected rule to be in rulesIndex")
return
}
if index != 0 {
t.Errorf("expected rule index to be 0, got %d", index)
return
}
}) })
} }
} }
@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
return return
} }
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rules is not empty") t.Errorf("rules is not empty")
} }
} }
@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) {
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
for _, rule := range manager.outgoingRules { for _, arr := range manager.outgoingRules {
if rule.id == hookID { for _, rule := range arr {
found = true if rule.id == hookID {
break found = true
break
}
} }
} }
@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) {
} }
// Assert the hook is removed by checking it in the manager's outgoing rules // Assert the hook is removed by checking it in the manager's outgoing rules
for _, rule := range manager.outgoingRules { for _, arr := range manager.outgoingRules {
if rule.id == hookID { for _, rule := range arr {
t.Fatalf("The hook was not removed properly.") if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
} }
} }
} }