diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a4a7a6c3b..fa51122af 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -8,6 +8,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/google/uuid" + "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" fw "github.com/netbirdio/netbird/client/firewall" @@ -35,6 +36,8 @@ type Manager struct { inputDefaultRuleSpecs []string outputDefaultRuleSpecs []string wgIface iFaceMapper + + rulesets map[string]ruleset } // iFaceMapper defines subset methods of interface required for manager @@ -43,6 +46,11 @@ type iFaceMapper interface { Address() iface.WGAddress } +type ruleset struct { + rule *Rule + ips map[string]string +} + // Create iptables firewall manager func Create(wgIface iFaceMapper) (*Manager, error) { m := &Manager{ @@ -51,6 +59,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) { "-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()}, outputDefaultRuleSpecs: []string{ "-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()}, + rulesets: make(map[string]ruleset), + } + + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) } // init clients for booth ipv4 and ipv6 @@ -111,22 +124,45 @@ func (m *Manager) AddFiltering( if sPort != nil && sPort.Values != nil { sPortVal = strconv.Itoa(sPort.Values[0]) } + ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal) ruleID := uuid.New().String() if comment == "" { comment = ruleID } - specs := m.filterRuleSpecs( - "filter", - ip, - string(protocol), - sPortVal, - dPortVal, - direction, - action, - comment, - ) + if ipsetName != "" { + rs, rsExists := m.rulesets[ipsetName] + if !rsExists { + if err := ipset.Flush(ipsetName); err != nil { + log.Errorf("flush ipset %q before use it: %v", ipsetName, err) + } + if err := ipset.Create(ipsetName); err != nil { + return nil, fmt.Errorf("failed to create ipset: %w", err) + } + } + + if err := ipset.Add(ipsetName, ip.String()); err != nil { + return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + } + + if rsExists { + // if ruleset already exists it means we already have the firewall rule + // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. + rs.ips[ip.String()] = ruleID + return &Rule{ + ruleID: ruleID, + ipsetName: ipsetName, + ip: ip.String(), + dst: direction == fw.RuleDirectionOUT, + v6: ip.To4() == nil, + }, nil + } + // this is new ipset so we need to create firewall rule for it + } + + specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal, + direction, action, comment, ipsetName) if direction == fw.RuleDirectionOUT { ok, err := client.Exists("filter", ChainOutputFilterName, specs...) @@ -154,12 +190,24 @@ func (m *Manager) AddFiltering( } } - return &Rule{ - id: ruleID, - specs: specs, - dst: direction == fw.RuleDirectionOUT, - v6: ip.To4() == nil, - }, nil + rule := &Rule{ + ruleID: ruleID, + specs: specs, + ipsetName: ipsetName, + ip: ip.String(), + dst: direction == fw.RuleDirectionOUT, + v6: ip.To4() == nil, + } + if ipsetName != "" { + // ipset name is defined and it means that this rule was created + // for it, need to assosiate it with ruleset + m.rulesets[ipsetName] = ruleset{ + rule: rule, + ips: map[string]string{rule.ip: ruleID}, + } + } + + return rule, nil } // DeleteRule from the firewall by rule definition @@ -180,6 +228,31 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { client = m.ipv6Client } + if rs, ok := m.rulesets[r.ipsetName]; ok { + // delete IP from ruleset IPs list and ipset + if _, ok := rs.ips[r.ip]; ok { + if err := ipset.Del(r.ipsetName, r.ip); err != nil { + return fmt.Errorf("failed to delete ip from ipset: %w", err) + } + delete(rs.ips, r.ip) + } + + // if after delete, set still contains other IPs, + // no need to delete firewall rule and we should exit here + if len(rs.ips) != 0 { + return nil + } + + // we delete last IP from the set, that means we need to delete + // set itself and assosiated firewall rule too + delete(m.rulesets, r.ipsetName) + + if err := ipset.Destroy(r.ipsetName); err != nil { + log.Errorf("delete empty ipset: %v", err) + } + r = rs.rule + } + if r.dst { return client.Delete("filter", ChainOutputFilterName, r.specs...) } @@ -246,6 +319,16 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error { return nil } + for ipsetName := range m.rulesets { + if err := ipset.Flush(ipsetName); err != nil { + log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + } + if err := ipset.Destroy(ipsetName); err != nil { + log.Errorf("delete ipset %q during reset: %v", ipsetName, err) + } + delete(m.rulesets, ipsetName) + } + return nil } @@ -253,6 +336,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error { func (m *Manager) filterRuleSpecs( table string, ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, comment string, + ipsetName string, ) (specs []string) { matchByIP := true // don't use IP matching if IP is ip 0.0.0.0 @@ -262,11 +346,19 @@ func (m *Manager) filterRuleSpecs( switch direction { case fw.RuleDirectionIN: if matchByIP { - specs = append(specs, "-s", ip.String()) + if ipsetName != "" { + specs = append(specs, "-m", "set", "--set", ipsetName, "src") + } else { + specs = append(specs, "-s", ip.String()) + } } case fw.RuleDirectionOUT: if matchByIP { - specs = append(specs, "-d", ip.String()) + if ipsetName != "" { + specs = append(specs, "-m", "set", "--set", ipsetName, "dst") + } else { + specs = append(specs, "-d", ip.String()) + } } } if protocol != "all" { @@ -348,3 +440,16 @@ func (m *Manager) actionToStr(action fw.Action) string { } return "DROP" } + +func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string { + if ipsetName == "" { + return "" + } else if sPort != "" && dPort != "" { + return ipsetName + "-sport-dport" + } else if sPort != "" { + return ipsetName + "-sport" + } else if dPort != "" { + return ipsetName + "-dport" + } + return ipsetName +} diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index cbf9d4c76..84e27ed14 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -55,12 +55,13 @@ func TestIptablesManager(t *testing.T) { // just check on the local interface manager, err := Create(mock) require.NoError(t, err) + time.Sleep(time.Second) defer func() { - if err := manager.Reset(); err != nil { - t.Errorf("clear the manager state: %v", err) - } + err := manager.Reset() + require.NoError(t, err, "clear the manager state") + time.Sleep(time.Second) }() @@ -88,19 +89,17 @@ func TestIptablesManager(t *testing.T) { }) t.Run("delete first rule", func(t *testing.T) { - if err := manager.DeleteRule(rule1); err != nil { - require.NoError(t, err, "failed to delete rule") - } + err := manager.DeleteRule(rule1) + require.NoError(t, err, "failed to delete rule") checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...) }) t.Run("delete second rule", func(t *testing.T) { - if err := manager.DeleteRule(rule2); err != nil { - require.NoError(t, err, "failed to delete rule") - } + err := manager.DeleteRule(rule2) + require.NoError(t, err, "failed to delete rule") - checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...) + require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty") }) t.Run("reset check", func(t *testing.T) { @@ -122,6 +121,88 @@ func TestIptablesManager(t *testing.T) { }) } +func TestIptablesManagerIPSet(t *testing.T) { + ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err) + + mock := &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("10.20.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("10.20.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, + } + + // just check on the local interface + manager, err := Create(mock) + require.NoError(t, err) + + time.Sleep(time.Second) + + defer func() { + err := manager.Reset() + require.NoError(t, err, "clear the manager state") + + time.Sleep(time.Second) + }() + + var rule1 fw.Rule + t.Run("add first rule with set", func(t *testing.T) { + ip := net.ParseIP("10.20.0.2") + port := &fw.Port{Values: []int{8080}} + rule1, err = manager.AddFiltering( + ip, "tcp", nil, port, fw.RuleDirectionOUT, + fw.ActionAccept, "default", "accept HTTP traffic", + ) + require.NoError(t, err, "failed to add rule") + + checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) + require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set") + require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set") + }) + + var rule2 fw.Rule + t.Run("add second rule", func(t *testing.T) { + ip := net.ParseIP("10.20.0.3") + port := &fw.Port{ + Values: []int{443}, + } + rule2, err = manager.AddFiltering( + ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, + "default", "accept HTTPS traffic from ports range", + ) + require.NoError(t, err, "failed to add rule") + require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set") + require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set") + }) + + t.Run("delete first rule", func(t *testing.T) { + err := manager.DeleteRule(rule1) + require.NoError(t, err, "failed to delete rule") + + require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index") + }) + + t.Run("delete second rule", func(t *testing.T) { + err := manager.DeleteRule(rule2) + require.NoError(t, err, "failed to delete rule") + + require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty") + }) + + t.Run("reset check", func(t *testing.T) { + err = manager.Reset() + require.NoError(t, err, "failed to reset") + }) +} + func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) { exists, err := ipv4Client.Exists("filter", chainName, rulespec...) require.NoError(t, err, "failed to check rule") @@ -153,9 +234,9 @@ func TestIptablesCreatePerformance(t *testing.T) { time.Sleep(time.Second) defer func() { - if err := manager.Reset(); err != nil { - t.Errorf("clear the manager state: %v", err) - } + err := manager.Reset() + require.NoError(t, err, "clear the manager state") + time.Sleep(time.Second) }() diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index a417891ea..f65030d39 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -2,13 +2,16 @@ package iptables // Rule to handle management of rules type Rule struct { - id string + ruleID string + ipsetName string + specs []string + ip string dst bool v6 bool } // GetRuleID returns the rule id func (r *Rule) GetRuleID() string { - return r.id + return r.ruleID } diff --git a/go.mod b/go.mod index 9e3d49490..4cb1c48c2 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/mdlayher/socket v0.4.0 github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 + github.com/nadoo/ipset v0.5.0 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pion/logging v0.2.2 diff --git a/go.sum b/go.sum index 73d293084..b374e507e 100644 --- a/go.sum +++ b/go.sum @@ -485,6 +485,8 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= +github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= +github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=