diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index adc2d552a..e834cf8b6 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -23,8 +23,8 @@ type IFaceMapper interface { // Manager userspace firewall manager type Manager struct { - outgoingRules []Rule - incomingRules []Rule + outgoingRules map[string][]Rule + incomingRules map[string][]Rule rulesIndex map[string]int wgNetwork *net.IPNet decoders sync.Pool @@ -62,6 +62,8 @@ func Create(iface IFaceMapper) (*Manager, error) { return d }, }, + outgoingRules: make(map[string][]Rule), + incomingRules: make(map[string][]Rule), } if err := iface.SetFilter(m); err != nil { @@ -126,10 +128,10 @@ func (m *Manager) AddFiltering( m.mutex.Lock() var p int if direction == fw.RuleDirectionIN { - m.incomingRules = append(m.incomingRules, r) + m.incomingRules[r.ip.String()] = append(m.incomingRules[r.ip.String()], r) p = len(m.incomingRules) - 1 } else { - m.outgoingRules = append(m.outgoingRules, r) + m.outgoingRules[r.ip.String()] = append(m.outgoingRules[r.ip.String()], r) p = len(m.outgoingRules) - 1 } m.rulesIndex[r.id] = p @@ -156,11 +158,11 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { var toUpdate []Rule if r.direction == fw.RuleDirectionIN { - m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...) - toUpdate = m.incomingRules + m.incomingRules[r.ip.String()] = append(m.incomingRules[r.ip.String()][:p], m.incomingRules[r.ip.String()][p+1:]...) + toUpdate = m.incomingRules[r.ip.String()] } else { - m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...) - toUpdate = m.outgoingRules + m.outgoingRules[r.ip.String()] = append(m.outgoingRules[r.ip.String()][:p], m.outgoingRules[r.ip.String()][p+1:]...) + toUpdate = m.outgoingRules[r.ip.String()] } for i := 0; i < len(toUpdate); i++ { @@ -174,8 +176,8 @@ func (m *Manager) Reset() error { m.mutex.Lock() defer m.mutex.Unlock() - m.outgoingRules = m.outgoingRules[:0] - m.incomingRules = m.incomingRules[:0] + m.outgoingRules = make(map[string][]Rule) + m.incomingRules = make(map[string][]Rule) m.rulesIndex = make(map[string]int) return nil @@ -192,7 +194,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool { } // dropFilter imlements same logic for booth direction of the traffic -func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool { +func (m *Manager) dropFilter(packetData []byte, rules map[string][]Rule, isIncomingPacket bool) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -226,29 +228,37 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b } payloadLayer := d.decoded[1] + var srcIP, dstIP net.IP + var ipRules []Rule + switch ipLayer { + case layers.LayerTypeIPv4: + if isIncomingPacket { + srcIP = d.ip4.SrcIP + ipRules = rules[srcIP.String()] + } else { + dstIP = d.ip4.DstIP + ipRules = rules[dstIP.String()] + } + case layers.LayerTypeIPv6: + if isIncomingPacket { + srcIP = d.ip6.SrcIP + ipRules = rules[srcIP.String()] + } else { + dstIP = d.ip6.DstIP + ipRules = rules[dstIP.String()] + } + } + // check if IP address match by IP - for _, rule := range rules { + for _, rule := range ipRules { if rule.matchByIP { - switch ipLayer { - case layers.LayerTypeIPv4: - if isIncomingPacket { - if !d.ip4.SrcIP.Equal(rule.ip) { - continue - } - } else { - if !d.ip4.DstIP.Equal(rule.ip) { - continue - } + if isIncomingPacket { + if !srcIP.Equal(rule.ip) { + continue } - case layers.LayerTypeIPv6: - if isIncomingPacket { - if !d.ip6.SrcIP.Equal(rule.ip) { - continue - } - } else { - if !d.ip6.DstIP.Equal(rule.ip) { - continue - } + } else { + if !dstIP.Equal(rule.ip) { + continue } } } @@ -328,11 +338,11 @@ func (m *Manager) AddUDPPacketHook( var toUpdate []Rule if in { r.direction = fw.RuleDirectionIN - m.incomingRules = append([]Rule{r}, m.incomingRules...) - toUpdate = m.incomingRules + m.incomingRules[r.ip.String()] = append([]Rule{r}, m.incomingRules[r.ip.String()]...) + toUpdate = m.incomingRules[r.ip.String()] } else { - m.outgoingRules = append([]Rule{r}, m.outgoingRules...) - toUpdate = m.outgoingRules + m.outgoingRules[r.ip.String()] = append([]Rule{r}, m.outgoingRules[r.ip.String()]...) + toUpdate = m.outgoingRules[r.ip.String()] } for i := range toUpdate { @@ -345,14 +355,18 @@ func (m *Manager) AddUDPPacketHook( // RemovePacketHook removes packet hook by given ID func (m *Manager) RemovePacketHook(hookID string) error { - for _, r := range m.incomingRules { - if r.id == hookID { - return m.DeleteRule(&r) + for _, arr := range m.incomingRules { + for _, r := range arr { + if r.id == hookID { + return m.DeleteRule(&r) + } } } - for _, r := range m.outgoingRules { - if r.id == hookID { - return m.DeleteRule(&r) + for _, arr := range m.outgoingRules { + for _, r := range arr { + if r.id == hookID { + return m.DeleteRule(&r) + } } } return fmt.Errorf("hook with given id not found")