diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index adc2d552a..5cc215256 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -21,11 +21,13 @@ type IFaceMapper interface { SetFilter(iface.PacketFilter) error } +// RuleSet is a set of rules grouped by a string key +type RuleSet map[string]Rule + // Manager userspace firewall manager type Manager struct { - outgoingRules []Rule - incomingRules []Rule - rulesIndex map[string]int + outgoingRules map[string]RuleSet + incomingRules map[string]RuleSet wgNetwork *net.IPNet decoders sync.Pool @@ -48,7 +50,6 @@ type decoder struct { // Create userspace firewall manager constructor func Create(iface IFaceMapper) (*Manager, error) { m := &Manager{ - rulesIndex: make(map[string]int), decoders: sync.Pool{ New: func() any { d := &decoder{ @@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) { return d }, }, + outgoingRules: make(map[string]RuleSet), + incomingRules: make(map[string]RuleSet), } if err := iface.SetFilter(m); err != nil { @@ -124,15 +127,17 @@ func (m *Manager) AddFiltering( } m.mutex.Lock() - var p int if direction == fw.RuleDirectionIN { - m.incomingRules = append(m.incomingRules, r) - p = len(m.incomingRules) - 1 + if _, ok := m.incomingRules[r.ip.String()]; !ok { + m.incomingRules[r.ip.String()] = make(RuleSet) + } + m.incomingRules[r.ip.String()][r.id] = r } else { - m.outgoingRules = append(m.outgoingRules, r) - p = len(m.outgoingRules) - 1 + if _, ok := m.outgoingRules[r.ip.String()]; !ok { + m.outgoingRules[r.ip.String()] = make(RuleSet) + } + m.outgoingRules[r.ip.String()][r.id] = r } - m.rulesIndex[r.id] = p m.mutex.Unlock() return &r, nil @@ -148,24 +153,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } - p, ok := m.rulesIndex[r.id] - if !ok { - return fmt.Errorf("delete rule: no rule with such id: %v", r.id) - } - delete(m.rulesIndex, r.id) - - var toUpdate []Rule if r.direction == fw.RuleDirectionIN { - m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...) - toUpdate = m.incomingRules + _, ok := m.incomingRules[r.ip.String()][r.id] + if !ok { + return fmt.Errorf("delete rule: no rule with such id: %v", r.id) + } + delete(m.incomingRules[r.ip.String()], r.id) } else { - m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...) - toUpdate = m.outgoingRules + _, ok := m.outgoingRules[r.ip.String()][r.id] + if !ok { + return fmt.Errorf("delete rule: no rule with such id: %v", r.id) + } + delete(m.outgoingRules[r.ip.String()], r.id) } - for i := 0; i < len(toUpdate); i++ { - m.rulesIndex[toUpdate[i].id] = i - } return nil } @@ -174,9 +175,8 @@ func (m *Manager) Reset() error { m.mutex.Lock() defer m.mutex.Unlock() - m.outgoingRules = m.outgoingRules[:0] - m.incomingRules = m.incomingRules[:0] - m.rulesIndex = make(map[string]int) + m.outgoingRules = make(map[string]RuleSet) + m.incomingRules = make(map[string]RuleSet) return nil } @@ -192,7 +192,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool { } // dropFilter imlements same logic for booth direction of the traffic -func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool { +func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -224,37 +224,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b log.Errorf("unknown layer: %v", d.decoded[0]) return true } - payloadLayer := d.decoded[1] - // check if IP address match by IP + var ip net.IP + switch ipLayer { + case layers.LayerTypeIPv4: + if isIncomingPacket { + ip = d.ip4.SrcIP + } else { + ip = d.ip4.DstIP + } + case layers.LayerTypeIPv6: + if isIncomingPacket { + ip = d.ip6.SrcIP + } else { + ip = d.ip6.DstIP + } + } + + filter, ok := validateRule(ip, packetData, rules[ip.String()], d) + if ok { + return filter + } + filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d) + if ok { + return filter + } + filter, ok = validateRule(ip, packetData, rules["::"], d) + if ok { + return filter + } + + // default policy is DROP ALL + return true +} + +func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) { + payloadLayer := d.decoded[1] for _, rule := range rules { - if rule.matchByIP { - switch ipLayer { - case layers.LayerTypeIPv4: - if isIncomingPacket { - if !d.ip4.SrcIP.Equal(rule.ip) { - continue - } - } else { - if !d.ip4.DstIP.Equal(rule.ip) { - continue - } - } - case layers.LayerTypeIPv6: - if isIncomingPacket { - if !d.ip6.SrcIP.Equal(rule.ip) { - continue - } - } else { - if !d.ip6.DstIP.Equal(rule.ip) { - continue - } - } - } + if rule.matchByIP && !ip.Equal(rule.ip) { + continue } if rule.protoLayer == layerTypeAll { - return rule.drop + return rule.drop, true } if payloadLayer != rule.protoLayer { @@ -264,38 +276,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b switch payloadLayer { case layers.LayerTypeTCP: if rule.sPort == 0 && rule.dPort == 0 { - return rule.drop + return rule.drop, true } if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) { - return rule.drop + return rule.drop, true } if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) { - return rule.drop + return rule.drop, true } case layers.LayerTypeUDP: // if rule has UDP hook (and if we are here we match this rule) // we ignore rule.drop and call this hook if rule.udpHook != nil { - return rule.udpHook(packetData) + return rule.udpHook(packetData), true } if rule.sPort == 0 && rule.dPort == 0 { - return rule.drop + return rule.drop, true } if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) { - return rule.drop + return rule.drop, true } if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { - return rule.drop + return rule.drop, true } - return rule.drop + return rule.drop, true case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: - return rule.drop + return rule.drop, true } } - - // default policy is DROP ALL - return true + return false, false } // SetNetwork of the wireguard interface to which filtering applied @@ -325,19 +335,19 @@ func (m *Manager) AddUDPPacketHook( } m.mutex.Lock() - var toUpdate []Rule if in { r.direction = fw.RuleDirectionIN - m.incomingRules = append([]Rule{r}, m.incomingRules...) - toUpdate = m.incomingRules + if _, ok := m.incomingRules[r.ip.String()]; !ok { + m.incomingRules[r.ip.String()] = make(map[string]Rule) + } + m.incomingRules[r.ip.String()][r.id] = r } else { - m.outgoingRules = append([]Rule{r}, m.outgoingRules...) - toUpdate = m.outgoingRules + if _, ok := m.outgoingRules[r.ip.String()]; !ok { + m.outgoingRules[r.ip.String()] = make(map[string]Rule) + } + m.outgoingRules[r.ip.String()][r.id] = r } - for i := range toUpdate { - m.rulesIndex[toUpdate[i].id] = i - } m.mutex.Unlock() return r.id @@ -345,14 +355,18 @@ func (m *Manager) AddUDPPacketHook( // RemovePacketHook removes packet hook by given ID func (m *Manager) RemovePacketHook(hookID string) error { - for _, r := range m.incomingRules { - if r.id == hookID { - return m.DeleteRule(&r) + for _, arr := range m.incomingRules { + for _, r := range arr { + if r.id == hookID { + return m.DeleteRule(&r) + } } } - for _, r := range m.outgoingRules { - if r.id == hookID { - return m.DeleteRule(&r) + for _, arr := range m.outgoingRules { + for _, r := range arr { + if r.id == hookID { + return m.DeleteRule(&r) + } } } return fmt.Errorf("hook with given id not found") diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index eed31c627..c7f38a44f 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) { return } - if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 { - t.Errorf("rule2 is not in the rulesIndex") + if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok { + t.Errorf("rule2 is not in the incomingRules") } err = m.DeleteRule(rule2) @@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) { return } - if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 { - t.Errorf("rule1 still in the rulesIndex") + if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok { + t.Errorf("rule2 is not in the incomingRules") } } @@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { manager := &Manager{ - incomingRules: []Rule{}, - outgoingRules: []Rule{}, - rulesIndex: make(map[string]int), + incomingRules: map[string]RuleSet{}, + outgoingRules: map[string]RuleSet{}, } manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) var addedRule Rule if tt.in { - if len(manager.incomingRules) != 1 { + if len(manager.incomingRules[tt.ip.String()]) != 1 { t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) return } - addedRule = manager.incomingRules[0] + for _, rule := range manager.incomingRules[tt.ip.String()] { + addedRule = rule + } } else { if len(manager.outgoingRules) != 1 { t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) return } - addedRule = manager.outgoingRules[0] + for _, rule := range manager.outgoingRules[tt.ip.String()] { + addedRule = rule + } } if !tt.ip.Equal(addedRule.ip) { @@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) { t.Errorf("expected udpHook to be set") return } - - // Ensure rulesIndex is correctly updated - index, ok := manager.rulesIndex[addedRule.id] - if !ok { - t.Errorf("expected rule to be in rulesIndex") - return - } - if index != 0 { - t.Errorf("expected rule index to be 0, got %d", index) - return - } }) } } @@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) { return } - if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { + if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { t.Errorf("rules is not empty") } } @@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) { // Assert the hook is added by finding it in the manager's outgoing rules found := false - for _, rule := range manager.outgoingRules { - if rule.id == hookID { - found = true - break + for _, arr := range manager.outgoingRules { + for _, rule := range arr { + if rule.id == hookID { + found = true + break + } } } @@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) { } // Assert the hook is removed by checking it in the manager's outgoing rules - for _, rule := range manager.outgoingRules { - if rule.id == hookID { - t.Fatalf("The hook was not removed properly.") + for _, arr := range manager.outgoingRules { + for _, rule := range arr { + if rule.id == hookID { + t.Fatalf("The hook was not removed properly.") + } } } }