Merge pull request #991 from netbirdio/fix/improve_uspfilter_performance

Improve userspace filter performance
This commit is contained in:
pascal-fischer 2023-07-12 18:02:29 +02:00 committed by GitHub
commit f40951cdf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 119 additions and 109 deletions

View File

@ -21,11 +21,13 @@ type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
}
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]Rule
// Manager userspace firewall manager
type Manager struct {
outgoingRules []Rule
incomingRules []Rule
rulesIndex map[string]int
outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
@ -48,7 +50,6 @@ type decoder struct {
// Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) {
m := &Manager{
rulesIndex: make(map[string]int),
decoders: sync.Pool{
New: func() any {
d := &decoder{
@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
return d
},
},
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
}
if err := iface.SetFilter(m); err != nil {
@ -124,15 +127,17 @@ func (m *Manager) AddFiltering(
}
m.mutex.Lock()
var p int
if direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules, r)
p = len(m.incomingRules) - 1
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else {
m.outgoingRules = append(m.outgoingRules, r)
p = len(m.outgoingRules) - 1
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
}
m.rulesIndex[r.id] = p
m.mutex.Unlock()
return &r, nil
@ -148,24 +153,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
p, ok := m.rulesIndex[r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.rulesIndex, r.id)
var toUpdate []Rule
if r.direction == fw.RuleDirectionIN {
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
toUpdate = m.incomingRules
_, ok := m.incomingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else {
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
toUpdate = m.outgoingRules
_, ok := m.outgoingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
}
for i := 0; i < len(toUpdate); i++ {
m.rulesIndex[toUpdate[i].id] = i
}
return nil
}
@ -174,9 +175,8 @@ func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = m.outgoingRules[:0]
m.incomingRules = m.incomingRules[:0]
m.rulesIndex = make(map[string]int)
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
@ -192,7 +192,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
}
// dropFilter imlements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
@ -224,37 +224,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
log.Errorf("unknown layer: %v", d.decoded[0])
return true
}
payloadLayer := d.decoded[1]
// check if IP address match by IP
var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
}
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter
}
// default policy is DROP ALL
return true
}
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP {
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
if !d.ip4.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip4.DstIP.Equal(rule.ip) {
continue
}
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
if !d.ip6.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip6.DstIP.Equal(rule.ip) {
continue
}
}
}
if rule.matchByIP && !ip.Equal(rule.ip) {
continue
}
if rule.protoLayer == layerTypeAll {
return rule.drop
return rule.drop, true
}
if payloadLayer != rule.protoLayer {
@ -264,38 +276,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
switch payloadLayer {
case layers.LayerTypeTCP:
if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop
return rule.drop, true
}
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
return rule.drop
return rule.drop, true
}
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
return rule.drop
return rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.udpHook(packetData)
return rule.udpHook(packetData), true
}
if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop
return rule.drop, true
}
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
return rule.drop
return rule.drop, true
}
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop
return rule.drop, true
}
return rule.drop
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop
return rule.drop, true
}
}
// default policy is DROP ALL
return true
return false, false
}
// SetNetwork of the wireguard interface to which filtering applied
@ -325,19 +335,19 @@ func (m *Manager) AddUDPPacketHook(
}
m.mutex.Lock()
var toUpdate []Rule
if in {
r.direction = fw.RuleDirectionIN
m.incomingRules = append([]Rule{r}, m.incomingRules...)
toUpdate = m.incomingRules
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
m.incomingRules[r.ip.String()][r.id] = r
} else {
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
toUpdate = m.outgoingRules
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
}
m.outgoingRules[r.ip.String()][r.id] = r
}
for i := range toUpdate {
m.rulesIndex[toUpdate[i].id] = i
}
m.mutex.Unlock()
return r.id
@ -345,14 +355,18 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range m.incomingRules {
if r.id == hookID {
return m.DeleteRule(&r)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
return m.DeleteRule(&r)
}
}
}
for _, r := range m.outgoingRules {
if r.id == hookID {
return m.DeleteRule(&r)
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
return m.DeleteRule(&r)
}
}
}
return fmt.Errorf("hook with given id not found")

View File

@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
t.Errorf("rule2 is not in the rulesIndex")
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
err = m.DeleteRule(rule2)
@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rule1 still in the rulesIndex")
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &Manager{
incomingRules: []Rule{},
outgoingRules: []Rule{},
rulesIndex: make(map[string]int),
incomingRules: map[string]RuleSet{},
outgoingRules: map[string]RuleSet{},
}
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule
if tt.in {
if len(manager.incomingRules) != 1 {
if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
addedRule = manager.incomingRules[0]
for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
} else {
if len(manager.outgoingRules) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
addedRule = manager.outgoingRules[0]
for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
}
if !tt.ip.Equal(addedRule.ip) {
@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected udpHook to be set")
return
}
// Ensure rulesIndex is correctly updated
index, ok := manager.rulesIndex[addedRule.id]
if !ok {
t.Errorf("expected rule to be in rulesIndex")
return
}
if index != 0 {
t.Errorf("expected rule index to be 0, got %d", index)
return
}
})
}
}
@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
return
}
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rules is not empty")
}
}
@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) {
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, rule := range manager.outgoingRules {
if rule.id == hookID {
found = true
break
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) {
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, rule := range manager.outgoingRules {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
}