From f03aadf064acd7b8b0b65ff991d8138f187bdccb Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Thu, 16 Mar 2023 13:00:08 +0400 Subject: [PATCH] Feat firewall controller interface (#740) Add a standard interface for the client firewall to support ACL. --- client/firewall/firewall.go | 57 +++++++ client/firewall/iptables/manager_linux.go | 160 ++++++++++++++++++ .../firewall/iptables/manager_linux_test.go | 105 ++++++++++++ client/firewall/iptables/rule.go | 13 ++ client/firewall/port.go | 24 +++ 5 files changed, 359 insertions(+) create mode 100644 client/firewall/firewall.go create mode 100644 client/firewall/iptables/manager_linux.go create mode 100644 client/firewall/iptables/manager_linux_test.go create mode 100644 client/firewall/iptables/rule.go create mode 100644 client/firewall/port.go diff --git a/client/firewall/firewall.go b/client/firewall/firewall.go new file mode 100644 index 000000000..2e685e15c --- /dev/null +++ b/client/firewall/firewall.go @@ -0,0 +1,57 @@ +package firewall + +import ( + "net" +) + +// Rule abstraction should be implemented by each firewall manager +// +// Each firewall type for different OS can use different type +// of the properties to hold data of the created rule +type Rule interface { + // GetRuleID returns the rule id + GetRuleID() string +} + +// Direction is the direction of the traffic +type Direction int + +const ( + // DirectionSrc is the direction of the traffic from the source + DirectionSrc Direction = iota + // DirectionDst is the direction of the traffic from the destination + DirectionDst +) + +// Action is the action to be taken on a rule +type Action int + +const ( + // ActionAccept is the action to accept a packet + ActionAccept Action = iota + // ActionDrop is the action to drop a packet + ActionDrop +) + +// Manager is the high level abstraction of a firewall manager +// +// It declares methods which handle actions required by the +// Netbird client for ACL and routing functionality +type Manager interface { + // AddFiltering rule to the firewall + AddFiltering( + ip net.IP, + port *Port, + direction Direction, + action Action, + comment string, + ) (Rule, error) + + // DeleteRule from the firewall by rule definition + DeleteRule(rule Rule) error + + // Reset firewall to the default state + Reset() error + + // TODO: migrate routemanager firewal actions to this interface +} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go new file mode 100644 index 000000000..e5aafd6d8 --- /dev/null +++ b/client/firewall/iptables/manager_linux.go @@ -0,0 +1,160 @@ +package iptables + +import ( + "fmt" + "net" + "strconv" + "sync" + + "github.com/coreos/go-iptables/iptables" + "github.com/google/uuid" + + fw "github.com/netbirdio/netbird/client/firewall" +) + +const ( + // ChainFilterName is the name of the chain that is used for filtering by the Netbird client + ChainFilterName = "NETBIRD-ACL" +) + +// Manager of iptables firewall +type Manager struct { + mutex sync.Mutex + + ipv4Client *iptables.IPTables + ipv6Client *iptables.IPTables +} + +// Create iptables firewall manager +func Create() (*Manager, error) { + m := &Manager{} + + // init clients for booth ipv4 and ipv6 + ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + return nil, fmt.Errorf("iptables is not installed in the system or not supported") + } + m.ipv4Client = ipv4Client + + ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return nil, fmt.Errorf("ip6tables is not installed in the system or not supported") + } + m.ipv6Client = ipv6Client + + if err := m.Reset(); err != nil { + return nil, fmt.Errorf("failed to reset firewall: %s", err) + } + + return m, nil +} + +// AddFiltering rule to the firewall +func (m *Manager) AddFiltering( + ip net.IP, + port *fw.Port, + direction fw.Direction, + action fw.Action, + comment string, +) (fw.Rule, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + client := m.client(ip) + ok, err := client.ChainExists("filter", ChainFilterName) + if err != nil { + return nil, fmt.Errorf("failed to check if chain exists: %s", err) + } + if !ok { + if err := client.NewChain("filter", ChainFilterName); err != nil { + return nil, fmt.Errorf("failed to create chain: %s", err) + } + } + if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) { + return nil, fmt.Errorf("invalid port definition") + } + pv := strconv.Itoa(port.Values[0]) + if port.IsRange { + pv += ":" + strconv.Itoa(port.Values[1]) + } + specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment) + if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil { + return nil, err + } + rule := &Rule{ + id: uuid.New().String(), + specs: specs, + v6: ip.To4() == nil, + } + return rule, nil +} + +// DeleteRule from the firewall by rule definition +func (m *Manager) DeleteRule(rule fw.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + r, ok := rule.(*Rule) + if !ok { + return fmt.Errorf("invalid rule type") + } + client := m.ipv4Client + if r.v6 { + client = m.ipv6Client + } + return client.Delete("filter", ChainFilterName, r.specs...) +} + +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil { + return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err) + } + if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil { + return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err) + } + return nil +} + +// reset firewall chain, clear it and drop it +func (m *Manager) reset(client *iptables.IPTables, table, chain string) error { + ok, err := client.ChainExists(table, chain) + if err != nil { + return fmt.Errorf("failed to check if chain exists: %w", err) + } + if !ok { + return nil + } + if err := client.ClearChain(table, ChainFilterName); err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + return client.DeleteChain(table, ChainFilterName) +} + +// filterRuleSpecs returns the specs of a filtering rule +func (m *Manager) filterRuleSpecs( + table string, chain string, ip net.IP, port string, + direction fw.Direction, action fw.Action, comment string, +) (specs []string) { + if direction == fw.DirectionSrc { + specs = append(specs, "-s", ip.String()) + } + specs = append(specs, "-p", "tcp", "--dport", port) + specs = append(specs, "-j", m.actionToStr(action)) + return append(specs, "-m", "comment", "--comment", comment) +} + +// client returns corresponding iptables client for the given ip +func (m *Manager) client(ip net.IP) *iptables.IPTables { + if ip.To4() != nil { + return m.ipv4Client + } + return m.ipv6Client +} + +func (m *Manager) actionToStr(action fw.Action) string { + if action == fw.ActionAccept { + return "ACCEPT" + } + return "DROP" +} diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go new file mode 100644 index 000000000..d576cb803 --- /dev/null +++ b/client/firewall/iptables/manager_linux_test.go @@ -0,0 +1,105 @@ +package iptables + +import ( + "net" + "testing" + + "github.com/coreos/go-iptables/iptables" + fw "github.com/netbirdio/netbird/client/firewall" +) + +func TestNewManager(t *testing.T) { + ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + t.Fatal(err) + } + + manager, err := Create() + if err != nil { + t.Fatal(err) + } + + var rule1 fw.Rule + t.Run("add first rule", func(t *testing.T) { + ip := net.ParseIP("10.20.0.2") + port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}} + rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic") + if err != nil { + t.Errorf("failed to add rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...) + }) + + var rule2 fw.Rule + t.Run("add second rule", func(t *testing.T) { + ip := net.ParseIP("10.20.0.3") + port := &fw.Port{ + Proto: fw.PortProtocolTCP, + Values: []int{8043: 8046}, + } + rule2, err = manager.AddFiltering( + ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range") + if err != nil { + t.Errorf("failed to add rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...) + }) + + t.Run("delete first rule", func(t *testing.T) { + if err := manager.DeleteRule(rule1); err != nil { + t.Errorf("failed to delete rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, false, rule1.(*Rule).specs...) + }) + + t.Run("delete second rule", func(t *testing.T) { + if err := manager.DeleteRule(rule2); err != nil { + t.Errorf("failed to delete rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...) + }) + + t.Run("reset check", func(t *testing.T) { + // add second rule + ip := net.ParseIP("10.20.0.3") + port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}} + _, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic") + if err != nil { + t.Errorf("failed to add rule: %v", err) + } + + if err := manager.Reset(); err != nil { + t.Errorf("failed to reset: %v", err) + } + + ok, err := ipv4Client.ChainExists("filter", ChainFilterName) + if err != nil { + t.Errorf("failed to drop chain: %v", err) + } + + if ok { + t.Errorf("chain '%v' still exists after Reset", ChainFilterName) + } + }) +} + +func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, mustExists bool, rulespec ...string) { + exists, err := ipv4Client.Exists("filter", ChainFilterName, rulespec...) + if err != nil { + t.Errorf("failed to check rule: %v", err) + return + } + + if !exists && mustExists { + t.Errorf("rule '%v' does not exist", rulespec) + return + } + if exists && !mustExists { + t.Errorf("rule '%v' exist", rulespec) + return + } +} diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go new file mode 100644 index 000000000..4b5807a9b --- /dev/null +++ b/client/firewall/iptables/rule.go @@ -0,0 +1,13 @@ +package iptables + +// Rule to handle management of rules +type Rule struct { + id string + specs []string + v6 bool +} + +// GetRuleID returns the rule id +func (r *Rule) GetRuleID() string { + return r.id +} diff --git a/client/firewall/port.go b/client/firewall/port.go new file mode 100644 index 000000000..fc09c51f3 --- /dev/null +++ b/client/firewall/port.go @@ -0,0 +1,24 @@ +package firewall + +// PortProtocol is the protocol of the port +type PortProtocol string + +const ( + // PortProtocolTCP is the TCP protocol + PortProtocolTCP PortProtocol = "tcp" + + // PortProtocolUDP is the UDP protocol + PortProtocolUDP PortProtocol = "udp" +) + +// Port of the address for firewall rule +type Port struct { + // IsRange is true Values contains two values, the first is the start port, the second is the end port + IsRange bool + + // Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports + Values []int + + // Proto is the protocol of the port + Proto PortProtocol +}