mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
[management, client] Add access control support to network routes (#2100)
This commit is contained in:
parent
a3a479429e
commit
ff7863785f
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,
|
ignore_words_list: erro,clienta,hastable,iif
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/iface"
|
import (
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
|
@ -19,24 +19,22 @@ const (
|
|||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||||
|
|
||||||
postRoutingMark = "0x000007e4"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type aclManager struct {
|
type aclManager struct {
|
||||||
iptablesClient *iptables.IPTables
|
iptablesClient *iptables.IPTables
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routeingFwChainName string
|
routingFwChainName string
|
||||||
|
|
||||||
entries map[string][][]string
|
entries map[string][][]string
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||||
m := &aclManager{
|
m := &aclManager{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
routeingFwChainName: routeingFwChainName,
|
routingFwChainName: routingFwChainName,
|
||||||
|
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
@ -61,7 +59,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
@ -127,7 +125,7 @@ func (m *aclManager) AddFiltering(
|
|||||||
return nil, fmt.Errorf("rule already exists")
|
return nil, fmt.Errorf("rule already exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,28 +137,16 @@ func (m *aclManager) AddFiltering(
|
|||||||
chain: chain,
|
chain: chain,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
return []firewall.Rule{rule}, nil
|
||||||
return []firewall.Rule{rule}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
|
||||||
if err != nil {
|
|
||||||
return []firewall.Rule{rule}, err
|
|
||||||
}
|
|
||||||
return []firewall.Rule{rule, rulePrerouting}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
r, ok := rule.(*Rule)
|
r, ok := rule.(*Rule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.chain == "PREROUTING" {
|
|
||||||
goto DELETERULE
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||||
// delete IP from ruleset IPs list and ipset
|
// delete IP from ruleset IPs list and ipset
|
||||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||||
@ -185,14 +171,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DELETERULE:
|
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
|
||||||
var table string
|
|
||||||
if r.chain == "PREROUTING" {
|
|
||||||
table = "mangle"
|
|
||||||
} else {
|
|
||||||
table = "filter"
|
|
||||||
}
|
|
||||||
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||||
}
|
}
|
||||||
@ -203,44 +182,6 @@ func (m *aclManager) Reset() error {
|
|||||||
return m.cleanChains()
|
return m.cleanChains()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
|
||||||
var src []string
|
|
||||||
if ipsetName != "" {
|
|
||||||
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
|
||||||
} else {
|
|
||||||
src = []string{"-s", ip.String()}
|
|
||||||
}
|
|
||||||
specs := []string{
|
|
||||||
"-d", m.wgIface.Address().IP.String(),
|
|
||||||
"-p", protocol,
|
|
||||||
"--dport", port,
|
|
||||||
"-j", "MARK", "--set-mark", postRoutingMark,
|
|
||||||
}
|
|
||||||
|
|
||||||
specs = append(src, specs...)
|
|
||||||
|
|
||||||
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
|
||||||
ruleID: uuid.New().String(),
|
|
||||||
specs: specs,
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
chain: "PREROUTING",
|
|
||||||
}
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
// todo write less destructive cleanup mechanism
|
||||||
func (m *aclManager) cleanChains() error {
|
func (m *aclManager) cleanChains() error {
|
||||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||||
@ -291,25 +232,6 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
for _, rule := range m.entries["PREROUTING"] {
|
|
||||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
@ -338,17 +260,9 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
|
|
||||||
for chainName, rules := range m.entries {
|
for chainName, rules := range m.entries {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if chainName == "FORWARD" {
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
// position 2 because we add it after router's, jump rule
|
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
return err
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -356,40 +270,29 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||||
|
// We want to make sure our traffic is not dropped by existing rules.
|
||||||
|
|
||||||
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
|
|
||||||
|
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
||||||
func (m *aclManager) seedInitialEntries() {
|
func (m *aclManager) seedInitialEntries() {
|
||||||
m.appendToEntries("INPUT",
|
|
||||||
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("INPUT",
|
established := getConntrackEstablished()
|
||||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("INPUT",
|
|
||||||
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
|
||||||
|
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||||
m.appendToEntries("OUTPUT",
|
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||||
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT",
|
|
||||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT",
|
|
||||||
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
||||||
|
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
|
||||||
|
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||||
|
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
||||||
m.appendToEntries("FORWARD",
|
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||||
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
|
||||||
m.appendToEntries("FORWARD",
|
|
||||||
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
|
||||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
|
||||||
|
|
||||||
m.appendToEntries("PREROUTING",
|
|
||||||
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||||
@ -456,18 +359,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
|||||||
return ipsetName
|
return ipsetName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
|
||||||
if proto == "all" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if direction != firewall.RuleDirectionIN {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
@ -21,7 +22,7 @@ type Manager struct {
|
|||||||
|
|
||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
aclMgr *aclManager
|
aclMgr *aclManager
|
||||||
router *routerManager
|
router *router
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
|||||||
ipv4Client: iptablesClient,
|
ipv4Client: iptablesClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.router, err = newRouterManager(context, iptablesClient)
|
m.router, err = newRouter(context, iptablesClient, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to initialize route related chains: %s", err)
|
log.Debugf("failed to initialize route related chains: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
@ -73,33 +74,62 @@ func (m *Manager) AddFiltering(
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
func (m *Manager) AddRouteFiltering(
|
||||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
sources [] netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.DeleteRule(rule)
|
if !destination.Addr().Is4() {
|
||||||
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.aclMgr.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.InsertRoutingRules(pair)
|
return m.router.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.RemoveRoutingRules(pair)
|
return m.router.RemoveNatRule(pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
net.ParseIP("0.0.0.0"),
|
net.ParseIP("0.0.0.0"),
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
_, err = m.AddFiltering(
|
_, err = m.AddPeerFiltering(
|
||||||
net.ParseIP("0.0.0.0"),
|
net.ParseIP("0.0.0.0"),
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
func getConntrackEstablished() []string {
|
||||||
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
|
}
|
||||||
|
@ -14,6 +14,21 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ifaceMock = &iFaceMock{
|
||||||
|
NameFunc: func() string {
|
||||||
|
return "lo"
|
||||||
|
},
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(context.Background(), ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule1 {
|
for _, r := range rule1 {
|
||||||
@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{8043: 8046},
|
Values: []int{8043: 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddPeerFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
for _, r := range rule1 {
|
for _, r := range rule1 {
|
||||||
err := manager.DeleteRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||||
@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeleteRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []int{5353}}
|
||||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
t.Run("add first rule with set", func(t *testing.T) {
|
t.Run("add first rule with set", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(
|
rule1, err = manager.AddPeerFiltering(
|
||||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
||||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
fw.ActionAccept, "default", "accept HTTP traffic",
|
||||||
)
|
)
|
||||||
@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{443},
|
Values: []int{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddPeerFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||||
"default", "accept HTTPS traffic from ports range",
|
"default", "accept HTTPS traffic from ports range",
|
||||||
)
|
)
|
||||||
@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
for _, r := range rule1 {
|
for _, r := range rule1 {
|
||||||
err := manager.DeleteRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||||
@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeleteRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||||
@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
@ -5,368 +5,478 @@ package iptables
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/nadoo/ipset"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Ipv4Forwarding = "netbird-rt-forwarding"
|
ipv4Nat = "netbird-rt-nat"
|
||||||
ipv4Nat = "netbird-rt-nat"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
const (
|
const (
|
||||||
tableFilter = "filter"
|
tableFilter = "filter"
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
chainFORWARD = "FORWARD"
|
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
chainRTFWD = "NETBIRD-RT-FWD"
|
chainRTFWD = "NETBIRD-RT-FWD"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
|
matchSet = "--match-set"
|
||||||
)
|
)
|
||||||
|
|
||||||
type routerManager struct {
|
type routeFilteringRuleParams struct {
|
||||||
ctx context.Context
|
Sources []netip.Prefix
|
||||||
stop context.CancelFunc
|
Destination netip.Prefix
|
||||||
iptablesClient *iptables.IPTables
|
Proto firewall.Protocol
|
||||||
rules map[string][]string
|
SPort *firewall.Port
|
||||||
|
DPort *firewall.Port
|
||||||
|
Direction firewall.RuleDirection
|
||||||
|
Action firewall.Action
|
||||||
|
SetName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
type router struct {
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
iptablesClient *iptables.IPTables
|
||||||
|
rules map[string][]string
|
||||||
|
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||||
|
wgIface iFaceMapper
|
||||||
|
legacyManagement bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
m := &routerManager{
|
r := &router{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := m.cleanUpDefaultForwardRules()
|
r.ipsetCounter = refcounter.New(
|
||||||
|
r.createIpSet,
|
||||||
|
func(name string, _ struct{}) error {
|
||||||
|
return r.deleteIpSet(name)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := ipset.Init(); err != nil {
|
||||||
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.cleanUpDefaultForwardRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to cleanup routing rules: %s", err)
|
log.Errorf("cleanup routing rules: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = m.createContainers()
|
err = r.createContainers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create containers for route: %s", err)
|
log.Errorf("create containers for route: %s", err)
|
||||||
}
|
}
|
||||||
return m, err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
func (r *router) AddRouteFiltering(
|
||||||
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
sources []netip.Prefix,
|
||||||
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
destination netip.Prefix,
|
||||||
if err != nil {
|
proto firewall.Protocol,
|
||||||
return err
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
var setName string
|
||||||
if err != nil {
|
if len(sources) > 1 {
|
||||||
return err
|
setName = firewall.GenerateSetName(sources)
|
||||||
|
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||||
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := routeFilteringRuleParams{
|
||||||
|
Sources: sources,
|
||||||
|
Destination: destination,
|
||||||
|
Proto: proto,
|
||||||
|
SPort: sPort,
|
||||||
|
DPort: dPort,
|
||||||
|
Action: action,
|
||||||
|
SetName: setName,
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := genRouteFilteringRuleSpec(params)
|
||||||
|
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||||
|
return nil, fmt.Errorf("add route rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[string(ruleKey)] = rule
|
||||||
|
|
||||||
|
return ruleKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
ruleKey := rule.GetRuleID()
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
|
setName := r.findSetNameInRule(rule)
|
||||||
|
|
||||||
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
||||||
|
return fmt.Errorf("delete route rule: %v", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if setName != "" {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove ipset: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSetNameInRule(rule []string) string {
|
||||||
|
for i, arg := range rule {
|
||||||
|
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||||
|
return rule[i+3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
|
||||||
|
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||||
|
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range sources {
|
||||||
|
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||||
|
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return struct{}{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) deleteIpSet(setName string) error {
|
||||||
|
if err := ipset.Destroy(setName); err != nil {
|
||||||
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if r.legacyManagement {
|
||||||
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !pair.Masquerade {
|
if !pair.Masquerade {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
if err := r.addNatRule(pair); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("add nat rule: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertRoutingRule inserts an iptables rule
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
var err error
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||||
existingRule, found := i.rules[ruleKey]
|
}
|
||||||
if found {
|
|
||||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
|
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||||
|
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||||
|
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleKey] = rule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(i.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
}
|
} else {
|
||||||
|
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
||||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
i.rules[ruleKey] = rule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
|
||||||
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
|
||||||
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pair.Masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
// GetLegacyManagement returns the current legacy management mode
|
||||||
var err error
|
func (r *router) GetLegacyManagement() bool {
|
||||||
|
return r.legacyManagement
|
||||||
|
}
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||||
existingRule, found := i.rules[ruleKey]
|
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||||
if found {
|
r.legacyManagement = isLegacy
|
||||||
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
}
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||||
|
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for k, rule := range r.rules {
|
||||||
|
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(i.rules, ruleKey)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *routerManager) RouteingFwChainName() string {
|
func (r *router) Reset() error {
|
||||||
return chainRTFWD
|
var merr *multierror.Error
|
||||||
}
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
func (i *routerManager) Reset() error {
|
|
||||||
err := i.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
i.rules = make(map[string][]string)
|
r.rules = make(map[string][]string)
|
||||||
return nil
|
|
||||||
|
if err := r.ipsetCounter.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
func (r *router) cleanUpDefaultForwardRules() error {
|
||||||
err := i.cleanJumpRules()
|
err := r.cleanJumpRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("flushing routing related tables")
|
log.Debug("flushing routing related tables")
|
||||||
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||||
if err != nil {
|
table := tableFilter
|
||||||
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
if chain == chainRTNAT {
|
||||||
return err
|
table = tableNat
|
||||||
} else if ok {
|
}
|
||||||
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
|
||||||
|
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
log.Errorf("failed check chain %s, error: %v", chain, err)
|
||||||
return err
|
return err
|
||||||
|
} else if ok {
|
||||||
|
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
|
||||||
return err
|
|
||||||
} else if ok {
|
|
||||||
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *routerManager) createContainers() error {
|
|
||||||
if i.rules[Ipv4Forwarding] != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
errMSGFormat := "failed creating chain %s,error: %v"
|
|
||||||
err := i.createChain(tableFilter, chainRTFWD)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.createChain(tableNat, chainRTNAT)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.addJumpRules()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while creating jump rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addJumpRules create jump rules to send packets to NetBird chains
|
func (r *router) createContainers() error {
|
||||||
func (i *routerManager) addJumpRules() error {
|
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||||
rule := []string{"-j", chainRTFWD}
|
if err := r.createAndSetupChain(chain); err != nil {
|
||||||
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
return fmt.Errorf("create chain %s: %v", chain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||||
|
return fmt.Errorf("insert established rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.addJumpRules()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) createAndSetupChain(chain string) error {
|
||||||
|
table := r.getTableForChain(chain)
|
||||||
|
|
||||||
|
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
||||||
|
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) getTableForChain(chain string) string {
|
||||||
|
if chain == chainRTNAT {
|
||||||
|
return tableNat
|
||||||
|
}
|
||||||
|
return tableFilter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) insertEstablishedRule(chain string) error {
|
||||||
|
establishedRule := getConntrackEstablished()
|
||||||
|
|
||||||
|
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to insert established rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := "established-" + chain
|
||||||
|
r.rules[ruleKey] = establishedRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addJumpRules() error {
|
||||||
|
rule := []string{"-j", chainRTNAT}
|
||||||
|
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
i.rules[Ipv4Forwarding] = rule
|
r.rules[ipv4Nat] = rule
|
||||||
|
|
||||||
rule = []string{"-j", chainRTNAT}
|
|
||||||
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv4Nat] = rule
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
func (r *router) cleanJumpRules() error {
|
||||||
func (i *routerManager) cleanJumpRules() error {
|
rule, found := r.rules[ipv4Nat]
|
||||||
var err error
|
|
||||||
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
|
||||||
rule, found := i.rules[Ipv4Forwarding]
|
|
||||||
if found {
|
if found {
|
||||||
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
|
||||||
}
|
|
||||||
}
|
|
||||||
rule, found = i.rules[ipv4Nat]
|
|
||||||
if found {
|
|
||||||
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to list rules: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ruleString := range rules {
|
|
||||||
if !strings.Contains(ruleString, "NETBIRD") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rule := strings.Fields(ruleString)
|
|
||||||
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ruleString := range rules {
|
|
||||||
if !strings.Contains(ruleString, "NETBIRD") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rule := strings.Fields(ruleString)
|
|
||||||
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *routerManager) createChain(table, newChain string) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
chains, err := i.iptablesClient.ListChains(table)
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldCreateChain := true
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
for _, chain := range chains {
|
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||||
if chain == newChain {
|
|
||||||
shouldCreateChain = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldCreateChain {
|
|
||||||
err = i.iptablesClient.NewChain(table, newChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the loopback return rule to the NAT chain
|
|
||||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
|
||||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNATRule appends an iptables rule pair to the nat chain
|
|
||||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
|
||||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
|
||||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
|
||||||
existingRule, found := i.rules[ruleKey]
|
|
||||||
if found {
|
|
||||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(i.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// inserting after loopback ignore rule
|
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
|
||||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.rules[ruleKey] = rule
|
r.rules[ruleKey] = rule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// genRuleSpec generates rule specification
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
func genRuleSpec(jump, source, destination string) []string {
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
|
||||||
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||||
|
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("nat rule %s not found", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getIptablesRuleType(table string) string {
|
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||||
ruleType := "forwarding"
|
intdir := "-i"
|
||||||
if table == tableNat {
|
if inverse {
|
||||||
ruleType = "nat"
|
intdir = "-o"
|
||||||
}
|
}
|
||||||
return ruleType
|
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||||
|
}
|
||||||
|
|
||||||
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
|
var rule []string
|
||||||
|
|
||||||
|
if params.SetName != "" {
|
||||||
|
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
||||||
|
} else if len(params.Sources) > 0 {
|
||||||
|
source := params.Sources[0]
|
||||||
|
rule = append(rule, "-s", source.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule, "-d", params.Destination.String())
|
||||||
|
|
||||||
|
if params.Proto != firewall.ProtocolALL {
|
||||||
|
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||||
|
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||||
|
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule, "-j", actionToStr(params.Action))
|
||||||
|
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPort(flag string, port *firewall.Port) []string {
|
||||||
|
if port == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
|
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(port.Values) > 1 {
|
||||||
|
portList := make([]string, len(port.Values))
|
||||||
|
for i, p := range port.Values {
|
||||||
|
portList[i] = strconv.Itoa(p)
|
||||||
|
}
|
||||||
|
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{flag, strconv.Itoa(port.Values[0])}
|
||||||
}
|
}
|
||||||
|
@ -4,11 +4,13 @@ package iptables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -37,26 +39,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
|
|
||||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
|
||||||
|
|
||||||
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Source: "100.100.100.1/32",
|
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||||
Destination: "100.100.100.0/24",
|
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
|
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
|
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
if !isIptablesSupported() {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
require.NoError(t, err, "forwarding pair should be inserted")
|
||||||
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
|
||||||
|
|
||||||
foundRule, found := manager.rules[forwardRuleKey]
|
|
||||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
|
||||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
|
||||||
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.True(t, exists, "income forwarding rule should exist")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[inForwardRuleKey]
|
|
||||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
|
||||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
require.True(t, exists, "nat rule should be created")
|
require.True(t, exists, "nat rule should be created")
|
||||||
@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||||
}
|
}
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
if !isIptablesSupported() {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
|
||||||
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = manager.Reset()
|
_ = manager.Reset()
|
||||||
@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
err = manager.RemoveNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.False(t, exists, "forwarding rule should not exist")
|
|
||||||
|
|
||||||
_, found := manager.rules[forwardRuleKey]
|
|
||||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
|
||||||
require.False(t, exists, "income forwarding rule should not exist")
|
|
||||||
|
|
||||||
_, found = manager.rules[inForwardRuleKey]
|
|
||||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
require.False(t, exists, "nat rule should not exist")
|
require.False(t, exists, "nat rule should not exist")
|
||||||
|
|
||||||
_, found = manager.rules[natRuleKey]
|
_, found := manager.rules[natRuleKey]
|
||||||
require.False(t, found, "nat rule should exist in the manager map")
|
require.False(t, found, "nat rule should exist in the manager map")
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||||
@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
|
|
||||||
_, found = manager.rules[inNatRuleKey]
|
_, found = manager.rules[inNatRuleKey]
|
||||||
require.False(t, found, "income nat rule should exist in the manager map")
|
require.False(t, found, "income nat rule should exist in the manager map")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||||
|
if !isIptablesSupported() {
|
||||||
|
t.Skip("iptables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
require.NoError(t, err, "Failed to create iptables client")
|
||||||
|
|
||||||
|
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
|
||||||
|
require.NoError(t, err, "Failed to create router manager")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := r.Reset()
|
||||||
|
require.NoError(t, err, "Failed to reset router")
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sources []netip.Prefix
|
||||||
|
destination netip.Prefix
|
||||||
|
proto firewall.Protocol
|
||||||
|
sPort *firewall.Port
|
||||||
|
dPort *firewall.Port
|
||||||
|
direction firewall.RuleDirection
|
||||||
|
action firewall.Action
|
||||||
|
expectSet bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic TCP rule with single source",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
proto: firewall.ProtocolTCP,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: &firewall.Port{Values: []int{80}},
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP rule with multiple sources",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
proto: firewall.ProtocolUDP,
|
||||||
|
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionOUT,
|
||||||
|
action: firewall.ActionDrop,
|
||||||
|
expectSet: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "All protocols rule",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||||
|
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
proto: firewall.ProtocolALL,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP rule",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
proto: firewall.ProtocolICMP,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP rule with multiple source ports",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
proto: firewall.ProtocolTCP,
|
||||||
|
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionOUT,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP rule with single IP and port range",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
proto: firewall.ProtocolUDP,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionDrop,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP rule with source and destination ports",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
proto: firewall.ProtocolTCP,
|
||||||
|
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||||
|
dPort: &firewall.Port{Values: []int{22}},
|
||||||
|
direction: firewall.RuleDirectionOUT,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop all incoming traffic",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
|
proto: firewall.ProtocolALL,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionDrop,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
|
// Check if the rule is in the internal map
|
||||||
|
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||||
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
|
// Log the internal rule
|
||||||
|
t.Logf("Internal rule: %v", rule)
|
||||||
|
|
||||||
|
// Check if the rule exists in iptables
|
||||||
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
||||||
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
|
// Verify rule content
|
||||||
|
params := routeFilteringRuleParams{
|
||||||
|
Sources: tt.sources,
|
||||||
|
Destination: tt.destination,
|
||||||
|
Proto: tt.proto,
|
||||||
|
SPort: tt.sPort,
|
||||||
|
DPort: tt.dPort,
|
||||||
|
Action: tt.action,
|
||||||
|
SetName: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedRule := genRouteFilteringRuleSpec(params)
|
||||||
|
|
||||||
|
if tt.expectSet {
|
||||||
|
setName := firewall.GenerateSetName(tt.sources)
|
||||||
|
params.SetName = setName
|
||||||
|
expectedRule = genRouteFilteringRuleSpec(params)
|
||||||
|
|
||||||
|
// Check if the set was created
|
||||||
|
_, exists := r.ipsetCounter.Get(setName)
|
||||||
|
assert.True(t, exists, "IPSet not created")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
err = r.DeleteRouteRule(ruleKey)
|
||||||
|
require.NoError(t, err, "Failed to delete rule")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,21 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NatFormat = "netbird-nat-%s"
|
ForwardingFormatPrefix = "netbird-fwd-"
|
||||||
ForwardingFormat = "netbird-fwd-%s"
|
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||||
InNatFormat = "netbird-nat-in-%s"
|
NatFormat = "netbird-nat-%s-%t"
|
||||||
InForwardingFormat = "netbird-fwd-in-%s"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule abstraction should be implemented by each firewall manager
|
// Rule abstraction should be implemented by each firewall manager
|
||||||
@ -49,11 +55,11 @@ type Manager interface {
|
|||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
AllowNetbird() error
|
AllowNetbird() error
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
AddFiltering(
|
AddPeerFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
@ -64,17 +70,25 @@ type Manager interface {
|
|||||||
comment string,
|
comment string,
|
||||||
) ([]Rule, error)
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
DeleteRule(rule Rule) error
|
DeletePeerRule(rule Rule) error
|
||||||
|
|
||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
// InsertRoutingRules inserts a routing firewall rule
|
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||||
InsertRoutingRules(pair RouterPair) error
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes a routing firewall rule
|
// DeleteRouteRule deletes a routing rule
|
||||||
RemoveRoutingRules(pair RouterPair) error
|
DeleteRouteRule(rule Rule) error
|
||||||
|
|
||||||
|
// AddNatRule inserts a routing NAT rule
|
||||||
|
AddNatRule(pair RouterPair) error
|
||||||
|
|
||||||
|
// RemoveNatRule removes a routing NAT rule
|
||||||
|
RemoveNatRule(pair RouterPair) error
|
||||||
|
|
||||||
|
// SetLegacyManagement sets the legacy management mode
|
||||||
|
SetLegacyManagement(legacy bool) error
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset() error
|
||||||
@ -83,6 +97,89 @@ type Manager interface {
|
|||||||
Flush() error
|
Flush() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, input string) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
return fmt.Sprintf(format, input)
|
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LegacyManager defines the interface for legacy management operations
|
||||||
|
type LegacyManager interface {
|
||||||
|
RemoveAllLegacyRouteRules() error
|
||||||
|
GetLegacyManagement() bool
|
||||||
|
SetLegacyManagement(bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
|
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||||
|
oldLegacy := router.GetLegacyManagement()
|
||||||
|
|
||||||
|
if oldLegacy != isLegacy {
|
||||||
|
router.SetLegacyManagement(isLegacy)
|
||||||
|
log.Debugf("Set legacy management to %v", isLegacy)
|
||||||
|
}
|
||||||
|
|
||||||
|
// client reconnected to a newer mgmt, we need to clean up the legacy rules
|
||||||
|
if !isLegacy && oldLegacy {
|
||||||
|
if err := router.RemoveAllLegacyRouteRules(); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Legacy routing rules removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||||
|
func GenerateSetName(sources []netip.Prefix) string {
|
||||||
|
// sort for consistent naming
|
||||||
|
sortPrefixes(sources)
|
||||||
|
|
||||||
|
var sourcesStr strings.Builder
|
||||||
|
for _, src := range sources {
|
||||||
|
sourcesStr.WriteString(src.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||||
|
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||||
|
|
||||||
|
return fmt.Sprintf("nb-%s", shortHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||||
|
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||||
|
if len(prefixes) == 0 {
|
||||||
|
return prefixes
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := []netip.Prefix{prefixes[0]}
|
||||||
|
for _, prefix := range prefixes[1:] {
|
||||||
|
last := merged[len(merged)-1]
|
||||||
|
if last.Contains(prefix.Addr()) {
|
||||||
|
// If the current prefix is contained within the last merged prefix, skip it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if prefix.Contains(last.Addr()) {
|
||||||
|
// If the current prefix contains the last merged prefix, replace it
|
||||||
|
merged[len(merged)-1] = prefix
|
||||||
|
} else {
|
||||||
|
// Otherwise, add the current prefix to the merged list
|
||||||
|
merged = append(merged, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortPrefixes sorts the given slice of netip.Prefix in place.
|
||||||
|
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||||
|
func sortPrefixes(prefixes []netip.Prefix) {
|
||||||
|
sort.Slice(prefixes, func(i, j int) bool {
|
||||||
|
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||||
|
if addrCmp != 0 {
|
||||||
|
return addrCmp < 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
|
||||||
|
return prefixes[i].Bits() > prefixes[j].Bits()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
192
client/firewall/manager/firewall_test.go
Normal file
192
client/firewall/manager/firewall_test.go
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
package manager_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateSetName(t *testing.T) {
|
||||||
|
t.Run("Different orders result in same hash", func(t *testing.T) {
|
||||||
|
prefixes1 := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
}
|
||||||
|
prefixes2 := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result1 := manager.GenerateSetName(prefixes1)
|
||||||
|
result2 := manager.GenerateSetName(prefixes2)
|
||||||
|
|
||||||
|
if result1 != result2 {
|
||||||
|
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Result format is correct", func(t *testing.T) {
|
||||||
|
prefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result := manager.GenerateSetName(prefixes)
|
||||||
|
|
||||||
|
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error matching regex: %v", err)
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Errorf("Result format is incorrect: %s", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||||
|
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||||
|
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||||
|
|
||||||
|
if result1 != result2 {
|
||||||
|
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
|
||||||
|
prefixes1 := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("2001:db8::/32"),
|
||||||
|
}
|
||||||
|
prefixes2 := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("2001:db8::/32"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result1 := manager.GenerateSetName(prefixes1)
|
||||||
|
result2 := manager.GenerateSetName(prefixes2)
|
||||||
|
|
||||||
|
if result1 != result2 {
|
||||||
|
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeIPRanges(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []netip.Prefix
|
||||||
|
expected []netip.Prefix
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty input",
|
||||||
|
input: []netip.Prefix{},
|
||||||
|
expected: []netip.Prefix{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single range",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Two non-overlapping ranges",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "One range containing another",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "One range containing another (different order)",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Overlapping ranges",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.128/25"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Overlapping ranges (different order)",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.128/25"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple overlapping ranges",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.2.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.128/25"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Partially overlapping ranges",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.0.0/23"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.2.0/25"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.0.0/23"),
|
||||||
|
netip.MustParsePrefix("192.168.2.0/25"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 ranges",
|
||||||
|
input: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("2001:db8::/32"),
|
||||||
|
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||||
|
netip.MustParsePrefix("2001:db8:2::/48"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("2001:db8::/32"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := manager.MergeIPRanges(tt.input)
|
||||||
|
if !reflect.DeepEqual(result, tt.expected) {
|
||||||
|
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,18 +1,26 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
type RouterPair struct {
|
type RouterPair struct {
|
||||||
ID string
|
ID route.ID
|
||||||
Source string
|
Source netip.Prefix
|
||||||
Destination string
|
Destination netip.Prefix
|
||||||
Masquerade bool
|
Masquerade bool
|
||||||
|
Inverse bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetInPair(pair RouterPair) RouterPair {
|
func GetInversePair(pair RouterPair) RouterPair {
|
||||||
return RouterPair{
|
return RouterPair{
|
||||||
ID: pair.ID,
|
ID: pair.ID,
|
||||||
// invert Source/Destination
|
// invert Source/Destination
|
||||||
Source: pair.Destination,
|
Source: pair.Destination,
|
||||||
Destination: pair.Source,
|
Destination: pair.Source,
|
||||||
Masquerade: pair.Masquerade,
|
Masquerade: pair.Masquerade,
|
||||||
|
Inverse: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -33,9 +33,10 @@ const (
|
|||||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const flushError = "flush: %w"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type AclManager struct {
|
type AclManager struct {
|
||||||
@ -48,7 +49,6 @@ type AclManager struct {
|
|||||||
chainInputRules *nftables.Chain
|
chainInputRules *nftables.Chain
|
||||||
chainOutputRules *nftables.Chain
|
chainOutputRules *nftables.Chain
|
||||||
chainFwFilter *nftables.Chain
|
chainFwFilter *nftables.Chain
|
||||||
chainPrerouting *nftables.Chain
|
|
||||||
|
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
rules map[string]*Rule
|
rules map[string]*Rule
|
||||||
@ -64,7 +64,7 @@ type iFaceMapper interface {
|
|||||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
|
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
|
||||||
// sConn is used for creating sets and adding/removing elements from them
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
// it's differ then rConn (which does create new conn for each flush operation)
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
// and is permanent. Using same connection for booth type of operations
|
// and is permanent. Using same connection for both type of operations
|
||||||
// overloads netlink with high amount of rules ( > 10000)
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *AclManager) AddFiltering(
|
func (m *AclManager) AddPeerFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
@ -120,20 +120,11 @@ func (m *AclManager) AddFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules = append(newRules, ioRule)
|
newRules = append(newRules, ioRule)
|
||||||
if !shouldAddToPrerouting(proto, dPort, direction) {
|
|
||||||
return newRules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
|
|
||||||
if err != nil {
|
|
||||||
return newRules, err
|
|
||||||
}
|
|
||||||
newRules = append(newRules, preroutingRule)
|
|
||||||
return newRules, nil
|
return newRules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
func (m *AclManager) DeleteRule(rule firewall.Rule) error {
|
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
r, ok := rule.(*Rule)
|
r, ok := rule.(*Rule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
@ -199,8 +190,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
|
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||||
// input and output chains
|
|
||||||
func (m *AclManager) createDefaultAllowRules() error {
|
func (m *AclManager) createDefaultAllowRules() error {
|
||||||
expIn := []expr.Any{
|
expIn := []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
@ -214,13 +204,13 @@ func (m *AclManager) createDefaultAllowRules() error {
|
|||||||
SourceRegister: 1,
|
SourceRegister: 1,
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Len: 4,
|
Len: 4,
|
||||||
Mask: []byte{0x00, 0x00, 0x00, 0x00},
|
Mask: []byte{0, 0, 0, 0},
|
||||||
Xor: zeroXor,
|
Xor: []byte{0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
// net address
|
// net address
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: []byte{0x00, 0x00, 0x00, 0x00},
|
Data: []byte{0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
@ -246,13 +236,13 @@ func (m *AclManager) createDefaultAllowRules() error {
|
|||||||
SourceRegister: 1,
|
SourceRegister: 1,
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Len: 4,
|
Len: 4,
|
||||||
Mask: []byte{0x00, 0x00, 0x00, 0x00},
|
Mask: []byte{0, 0, 0, 0},
|
||||||
Xor: zeroXor,
|
Xor: []byte{0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
// net address
|
// net address
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: []byte{0x00, 0x00, 0x00, 0x00},
|
Data: []byte{0, 0, 0, 0},
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
@ -266,10 +256,8 @@ func (m *AclManager) createDefaultAllowRules() error {
|
|||||||
Exprs: expOut,
|
Exprs: expOut,
|
||||||
})
|
})
|
||||||
|
|
||||||
err := m.rConn.Flush()
|
if err := m.rConn.Flush(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf(flushError, err)
|
||||||
log.Debugf("failed to create default allow rules: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -290,15 +278,11 @@ func (m *AclManager) Flush() error {
|
|||||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
||||||
ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
|
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
return &Rule{
|
return &Rule{
|
||||||
r.nftRule,
|
r.nftRule,
|
||||||
@ -308,18 +292,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
var expressions []expr.Any
|
||||||
if direction == firewall.RuleDirectionOUT {
|
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
|
||||||
}
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if proto != firewall.ProtocolALL {
|
if proto != firewall.ProtocolALL {
|
||||||
expressions = append(expressions, &expr.Payload{
|
expressions = append(expressions, &expr.Payload{
|
||||||
@ -329,21 +302,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
Len: uint32(1),
|
Len: uint32(1),
|
||||||
})
|
})
|
||||||
|
|
||||||
var protoData []byte
|
protoData, err := protoToInt(proto)
|
||||||
switch proto {
|
if err != nil {
|
||||||
case firewall.ProtocolTCP:
|
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||||
protoData = []byte{unix.IPPROTO_TCP}
|
|
||||||
case firewall.ProtocolUDP:
|
|
||||||
protoData = []byte{unix.IPPROTO_UDP}
|
|
||||||
case firewall.ProtocolICMP:
|
|
||||||
protoData = []byte{unix.IPPROTO_ICMP}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expressions = append(expressions, &expr.Cmp{
|
expressions = append(expressions, &expr.Cmp{
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Data: protoData,
|
Data: []byte{protoData},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -432,10 +399,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
} else {
|
} else {
|
||||||
chain = m.chainOutputRules
|
chain = m.chainOutputRules
|
||||||
}
|
}
|
||||||
nftRule := m.rConn.InsertRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Position: 0,
|
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
@ -453,139 +419,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
|
|
||||||
var protoData []byte
|
|
||||||
switch proto {
|
|
||||||
case firewall.ProtocolTCP:
|
|
||||||
protoData = []byte{unix.IPPROTO_TCP}
|
|
||||||
case firewall.ProtocolUDP:
|
|
||||||
protoData = []byte{unix.IPPROTO_UDP}
|
|
||||||
case firewall.ProtocolICMP:
|
|
||||||
protoData = []byte{unix.IPPROTO_ICMP}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
|
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
|
||||||
return &Rule{
|
|
||||||
r.nftRule,
|
|
||||||
r.nftSet,
|
|
||||||
r.ruleID,
|
|
||||||
ip,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var ipExpression expr.Any
|
|
||||||
// add individual IP for match if no ipset defined
|
|
||||||
rawIP := ip.To4()
|
|
||||||
if ipset == nil {
|
|
||||||
ipExpression = &expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: rawIP,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ipExpression = &expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ipset.Name,
|
|
||||||
SetID: ipset.ID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
ipExpression,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: m.wgIface.Address().IP.To4(),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(9),
|
|
||||||
Len: uint32(1),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Data: protoData,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if port != nil {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*port),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: postroutingMark,
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
SourceRegister: true,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
nftRule := m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainPrerouting,
|
|
||||||
Position: 0,
|
|
||||||
Exprs: expressions,
|
|
||||||
UserData: []byte(ruleId),
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
|
||||||
nftRule: nftRule,
|
|
||||||
nftSet: ipset,
|
|
||||||
ruleID: ruleId,
|
|
||||||
ip: ip,
|
|
||||||
}
|
|
||||||
|
|
||||||
m.rules[ruleId] = rule
|
|
||||||
if ipset != nil {
|
|
||||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
|
||||||
}
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) createDefaultChains() (err error) {
|
func (m *AclManager) createDefaultChains() (err error) {
|
||||||
// chainNameInputRules
|
// chainNameInputRules
|
||||||
chain := m.createChain(chainNameInputRules)
|
chain := m.createChain(chainNameInputRules)
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||||
return err
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
m.chainInputRules = chain
|
m.chainInputRules = chain
|
||||||
|
|
||||||
@ -601,9 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// netbird-acl-input-filter
|
// netbird-acl-input-filter
|
||||||
// type filter hook input priority filter; policy accept;
|
// type filter hook input priority filter; policy accept;
|
||||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||||
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
|
|
||||||
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
|
|
||||||
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
|
|
||||||
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||||
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
@ -615,7 +452,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// netbird-acl-output-filter
|
// netbird-acl-output-filter
|
||||||
// type filter hook output priority filter; policy accept;
|
// type filter hook output priority filter; policy accept;
|
||||||
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
|
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
|
||||||
m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
|
|
||||||
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
|
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
|
||||||
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
|
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
|
||||||
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
|
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
|
||||||
@ -627,24 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
|
|
||||||
// netbird-acl-forward-filter
|
// netbird-acl-forward-filter
|
||||||
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||||
m.addJumpRulesToRtForward() // to
|
m.addJumpRulesToRtForward() // to netbird-rt-fwd
|
||||||
m.addMarkAccept()
|
|
||||||
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
|
|
||||||
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
|
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
|
||||||
|
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||||
return err
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// netbird-acl-output-filter
|
|
||||||
// type filter hook output priority filter; policy accept;
|
|
||||||
m.chainPrerouting = m.createPreroutingMangle()
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -667,59 +494,6 @@ func (m *AclManager) addJumpRulesToRtForward() {
|
|||||||
Chain: m.chainFwFilter,
|
Chain: m.chainFwFilter,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
})
|
})
|
||||||
|
|
||||||
expressions = []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictJump,
|
|
||||||
Chain: m.routeingFwChainName,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainFwFilter,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addMarkAccept() {
|
|
||||||
// oifname "wt0" meta mark 0x000007e4 accept
|
|
||||||
// iifname "wt0" meta mark 0x000007e4 accept
|
|
||||||
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
|
|
||||||
for _, iface := range ifaces {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: iface, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: postroutingMark,
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainFwFilter,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) createChain(name string) *nftables.Chain {
|
func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||||
@ -729,6 +503,9 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
|
|||||||
}
|
}
|
||||||
|
|
||||||
chain = m.rConn.AddChain(chain)
|
chain = m.rConn.AddChain(chain)
|
||||||
|
|
||||||
|
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
||||||
|
|
||||||
return chain
|
return chain
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -746,74 +523,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha
|
|||||||
return m.rConn.AddChain(chain)
|
return m.rConn.AddChain(chain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
|
|
||||||
polAccept := nftables.ChainPolicyAccept
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: "netbird-acl-prerouting-filter",
|
|
||||||
Table: m.workTable,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
|
||||||
Priority: nftables.ChainPriorityMangle,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Policy: &polAccept,
|
|
||||||
}
|
|
||||||
|
|
||||||
chain = m.rConn.AddChain(chain)
|
|
||||||
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: m.wgIface.Address().IP.To4(),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: postroutingMark,
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
SourceRegister: true,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
return chain
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||||
expressions := []expr.Any{
|
expressions := []expr.Any{
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
@ -832,101 +541,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addJumpRuleToInputChain() {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictJump,
|
|
||||||
Chain: m.chainInputRules.Name,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainFwFilter,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
|
||||||
var srcOp, dstOp expr.CmpOp
|
|
||||||
if netIfName == expr.MetaKeyIIFNAME {
|
|
||||||
srcOp = expr.CmpOpEq
|
|
||||||
dstOp = expr.CmpOpNeq
|
|
||||||
} else {
|
|
||||||
srcOp = expr.CmpOpNeq
|
|
||||||
dstOp = expr.CmpOpEq
|
|
||||||
}
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: netIfName, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: srcOp,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: dstOp,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
||||||
var srcOp, dstOp expr.CmpOp
|
dstOp := expr.CmpOpNeq
|
||||||
if iifname == expr.MetaKeyIIFNAME {
|
|
||||||
srcOp = expr.CmpOpNeq
|
|
||||||
dstOp = expr.CmpOpEq
|
|
||||||
} else {
|
|
||||||
srcOp = expr.CmpOpEq
|
|
||||||
dstOp = expr.CmpOpNeq
|
|
||||||
}
|
|
||||||
expressions := []expr.Any{
|
expressions := []expr.Any{
|
||||||
&expr.Meta{Key: iifname, Register: 1},
|
&expr.Meta{Key: iifname, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@ -934,24 +551,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(m.wgIface.Name()),
|
Data: ifname(m.wgIface.Name()),
|
||||||
},
|
},
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: srcOp,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 2,
|
DestRegister: 2,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
@ -982,7 +581,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
|
||||||
expressions := []expr.Any{
|
expressions := []expr.Any{
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
@ -990,47 +588,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(m.wgIface.Name()),
|
Data: ifname(m.wgIface.Name()),
|
||||||
},
|
},
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictJump,
|
Kind: expr.VerdictJump,
|
||||||
Chain: to,
|
Chain: to,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: chain.Table,
|
Table: chain.Table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
@ -1132,7 +695,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateRuleId(
|
func generatePeerRuleId(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
@ -1155,33 +718,6 @@ func generateRuleId(
|
|||||||
}
|
}
|
||||||
return "set:" + ipset.Name + rulesetID
|
return "set:" + ipset.Name + rulesetID
|
||||||
}
|
}
|
||||||
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
|
|
||||||
// case of icmp port is empty
|
|
||||||
var p string
|
|
||||||
if port != nil {
|
|
||||||
p = port.String()
|
|
||||||
}
|
|
||||||
if ipset != nil {
|
|
||||||
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
|
||||||
if proto == "all" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if direction != firewall.RuleDirectionIN {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort == nil && proto != firewall.ProtocolICMP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodePort(port firewall.Port) []byte {
|
func encodePort(port firewall.Port) []byte {
|
||||||
bs := make([]byte, 2)
|
bs := make([]byte, 2)
|
||||||
@ -1191,6 +727,19 @@ func encodePort(port firewall.Port) []byte {
|
|||||||
|
|
||||||
func ifname(n string) []byte {
|
func ifname(n string) []byte {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
copy(b, []byte(n+"\x00"))
|
copy(b, n+"\x00")
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||||
|
switch protocol {
|
||||||
|
case firewall.ProtocolTCP:
|
||||||
|
return unix.IPPROTO_TCP, nil
|
||||||
|
case firewall.ProtocolUDP:
|
||||||
|
return unix.IPPROTO_UDP, nil
|
||||||
|
case firewall.ProtocolICMP:
|
||||||
|
return unix.IPPROTO_ICMP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
@ -5,9 +5,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@ -15,8 +17,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// tableName is the name of the table that is used for filtering by the Netbird client
|
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
|
||||||
tableName = "netbird"
|
tableNameNetbird = "netbird"
|
||||||
|
|
||||||
|
tableNameFilter = "filter"
|
||||||
|
chainNameInput = "INPUT"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.router, err = newRouter(context, workTable)
|
m.router, err = newRouter(context, workTable, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
@ -76,33 +81,52 @@ func (m *Manager) AddFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclManager.DeleteRule(rule)
|
if !destination.Addr().Is4() {
|
||||||
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.aclManager.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule deletes a routing rule
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.AddRoutingRules(pair)
|
return m.router.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.router.RemoveRoutingRules(pair)
|
return m.router.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
var chain *nftables.Chain
|
var chain *nftables.Chain
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
||||||
chain = c
|
chain = c
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
|
oldLegacy := m.router.legacyManagement
|
||||||
|
|
||||||
|
if oldLegacy != isLegacy {
|
||||||
|
m.router.legacyManagement = isLegacy
|
||||||
|
log.Debugf("Set legacy management to %v", isLegacy)
|
||||||
|
}
|
||||||
|
|
||||||
|
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
|
||||||
|
if !isLegacy && oldLegacy {
|
||||||
|
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Legacy routing rules removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset() error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@ -185,14 +230,16 @@ func (m *Manager) Reset() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.router.ResetForwardRules()
|
if err := m.router.Reset(); err != nil {
|
||||||
|
return fmt.Errorf("reset forward rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
tables, err := m.rConn.ListTables()
|
tables, err := m.rConn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
return fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == tableName {
|
if t.Name == tableNameNetbird {
|
||||||
m.rConn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == tableName {
|
if t.Name == tableNameNetbird {
|
||||||
m.rConn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
return table, err
|
return table, err
|
||||||
}
|
}
|
||||||
@ -239,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(m.wgIface.Name()),
|
Data: ifname(m.wgIface.Name()),
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{},
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
UserData: []byte(allowNetbirdInputRuleID),
|
UserData: []byte(allowNetbirdInputRuleID),
|
||||||
}
|
}
|
||||||
@ -251,7 +296,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
|||||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||||
ifName := ifname(m.wgIface.Name())
|
ifName := ifname(m.wgIface.Name())
|
||||||
for _, rule := range existedRules {
|
for _, rule := range existedRules {
|
||||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
|
||||||
if len(rule.Exprs) < 4 {
|
if len(rule.Exprs) < 4 {
|
||||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||||
continue
|
continue
|
||||||
@ -265,3 +310,33 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||||
|
rule := &nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.InsertRule(rule)
|
||||||
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
@ -17,6 +18,21 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ifaceMock = &iFaceMock{
|
||||||
|
NameFunc: func() string {
|
||||||
|
return "lo"
|
||||||
|
},
|
||||||
|
AddressFunc: func() iface.WGAddress {
|
||||||
|
return iface.WGAddress{
|
||||||
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
|||||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||||
|
|
||||||
func TestNftablesManager(t *testing.T) {
|
func TestNftablesManager(t *testing.T) {
|
||||||
mock := &iFaceMock{
|
|
||||||
NameFunc: func() string {
|
|
||||||
return "lo"
|
|
||||||
},
|
|
||||||
AddressFunc: func() iface.WGAddress {
|
|
||||||
return iface.WGAddress{
|
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
|
||||||
Network: &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(context.Background(), ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddFiltering(
|
rule, err := manager.AddPeerFiltering(
|
||||||
ip,
|
ip,
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
@ -88,17 +90,34 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
require.Len(t, rules, 1, "expected 1 rules")
|
require.Len(t, rules, 2, "expected 2 rules")
|
||||||
|
|
||||||
|
expectedExprs1 := []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
add := ipToAdd.Unmap()
|
add := ipToAdd.Unmap()
|
||||||
expectedExprs := []expr.Any{
|
expectedExprs2 := []expr.Any{
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
@ -134,10 +153,10 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||||
}
|
}
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
|
||||||
|
|
||||||
for _, r := range rule {
|
for _, r := range rule {
|
||||||
err = manager.DeleteRule(r)
|
err = manager.DeletePeerRule(r)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +165,8 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
// established rule remains
|
||||||
|
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
@ -187,9 +207,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
@ -1,431 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
chainNameRouteingFw = "netbird-rt-fwd"
|
|
||||||
chainNameRoutingNat = "netbird-rt-nat"
|
|
||||||
|
|
||||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
|
||||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
|
||||||
|
|
||||||
loopbackInterface = "lo\x00"
|
|
||||||
)
|
|
||||||
|
|
||||||
// some presets for building nftable rules
|
|
||||||
var (
|
|
||||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
|
||||||
|
|
||||||
exprCounterAccept = []expr.Any{
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type router struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
conn *nftables.Conn
|
|
||||||
workTable *nftables.Table
|
|
||||||
filterTable *nftables.Table
|
|
||||||
chains map[string]*nftables.Chain
|
|
||||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
|
||||||
rules map[string]*nftables.Rule
|
|
||||||
isDefaultFwdRulesEnabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
|
|
||||||
r := &router{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
workTable: workTable,
|
|
||||||
chains: make(map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
r.filterTable, err = r.loadFilterTable()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, errFilterTableNotFound) {
|
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.createContainers()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create containers for route: %s", err)
|
|
||||||
}
|
|
||||||
return r, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) RouteingFwChainName() string {
|
|
||||||
return chainNameRouteingFw
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetForwardRules cleans existing nftables default forward rules from the system
|
|
||||||
func (r *router) ResetForwardRules() {
|
|
||||||
err := r.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to reset forward rules: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range tables {
|
|
||||||
if table.Name == "filter" {
|
|
||||||
return table, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errFilterTableNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
|
||||||
|
|
||||||
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameRouteingFw,
|
|
||||||
Table: r.workTable,
|
|
||||||
})
|
|
||||||
|
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameRoutingNat,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
|
||||||
Priority: nftables.ChainPriorityNATSource - 1,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Add RETURN rule for loopback interface
|
|
||||||
loRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingNat],
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte(loopbackInterface),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
r.conn.InsertRule(loRule)
|
|
||||||
|
|
||||||
err := r.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
|
||||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
|
||||||
err := r.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair.Masquerade {
|
|
||||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
|
||||||
log.Debugf("add default accept forward rule")
|
|
||||||
r.acceptForwardRule(pair.Source)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
|
||||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
|
||||||
|
|
||||||
var expression []expr.Any
|
|
||||||
if isNat {
|
|
||||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
|
||||||
} else {
|
|
||||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleKey := manager.GenKey(format, pair.ID)
|
|
||||||
|
|
||||||
_, exists := r.rules[ruleKey]
|
|
||||||
if exists {
|
|
||||||
err := r.removeRoutingRule(format, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainName],
|
|
||||||
Exprs: expression,
|
|
||||||
UserData: []byte(ruleKey),
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) acceptForwardRule(sourceNetwork string) {
|
|
||||||
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
|
||||||
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
|
||||||
|
|
||||||
var exprs []expr.Any
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
rule := &nftables.Rule{
|
|
||||||
Table: r.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
|
||||||
}
|
|
||||||
|
|
||||||
r.conn.AddRule(rule)
|
|
||||||
|
|
||||||
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
|
||||||
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
|
||||||
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
rule = &nftables.Rule{
|
|
||||||
Table: r.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: r.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(rule)
|
|
||||||
r.isDefaultFwdRulesEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
|
||||||
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
|
||||||
err := r.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.NatFormat, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(r.rules) == 0 {
|
|
||||||
err := r.cleanUpDefaultForwardRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
|
||||||
}
|
|
||||||
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
|
||||||
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
|
||||||
ruleKey := manager.GenKey(format, pair.ID)
|
|
||||||
|
|
||||||
rule, found := r.rules[ruleKey]
|
|
||||||
if found {
|
|
||||||
ruleType := "forwarding"
|
|
||||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
|
||||||
ruleType = "nat"
|
|
||||||
}
|
|
||||||
|
|
||||||
err := r.conn.DelRule(rule)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
|
||||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
|
||||||
func (r *router) refreshRulesMap() error {
|
|
||||||
for _, chain := range r.chains {
|
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
|
||||||
}
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
r.rules[string(rule.UserData)] = rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) cleanUpDefaultForwardRules() error {
|
|
||||||
if r.filterTable == nil {
|
|
||||||
r.isDefaultFwdRulesEnabled = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []*nftables.Rule
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain.Table.Name != r.filterTable.Name {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if chain.Name != "FORWARD" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err = r.conn.GetRules(r.filterTable, chain)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
|
||||||
err := r.conn.DelRule(rule)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.isDefaultFwdRulesEnabled = false
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
|
||||||
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
|
||||||
ip, network, _ := net.ParseCIDR(cidr)
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
|
|
||||||
var offSet uint32
|
|
||||||
if source {
|
|
||||||
offSet = 12 // src offset
|
|
||||||
} else {
|
|
||||||
offSet = 16 // dst offset
|
|
||||||
}
|
|
||||||
|
|
||||||
return []expr.Any{
|
|
||||||
// fetch src add
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: offSet,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
// net mask
|
|
||||||
&expr.Bitwise{
|
|
||||||
DestRegister: 1,
|
|
||||||
SourceRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: network.Mask,
|
|
||||||
Xor: zeroXor,
|
|
||||||
},
|
|
||||||
// net address
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
798
client/firewall/nftables/router_linux.go
Normal file
798
client/firewall/nftables/router_linux.go
Normal file
@ -0,0 +1,798 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
|
chainNameRoutingNat = "netbird-rt-nat"
|
||||||
|
chainNameForward = "FORWARD"
|
||||||
|
|
||||||
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
|
)
|
||||||
|
|
||||||
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
|
|
||||||
|
var (
|
||||||
|
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
type router struct {
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
conn *nftables.Conn
|
||||||
|
workTable *nftables.Table
|
||||||
|
filterTable *nftables.Table
|
||||||
|
chains map[string]*nftables.Chain
|
||||||
|
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
|
rules map[string]*nftables.Rule
|
||||||
|
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||||
|
|
||||||
|
wgIface iFaceMapper
|
||||||
|
legacyManagement bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
|
r := &router{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
workTable: workTable,
|
||||||
|
chains: make(map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
|
wgIface: wgIface,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ipsetCounter = refcounter.New(
|
||||||
|
r.createIpSet,
|
||||||
|
r.deleteIpSet,
|
||||||
|
)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
r.filterTable, err = r.loadFilterTable()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errFilterTableNotFound) {
|
||||||
|
log.Warnf("table 'filter' not found for forward rules")
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.cleanUpDefaultForwardRules()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.createContainers()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create containers for route: %s", err)
|
||||||
|
}
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset cleans existing nftables default forward rules from the system
|
||||||
|
func (r *router) Reset() error {
|
||||||
|
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||||
|
r.ipsetCounter.Clear()
|
||||||
|
|
||||||
|
return r.cleanUpDefaultForwardRules()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) cleanUpDefaultForwardRules() error {
|
||||||
|
if r.filterTable == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||||
|
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range tables {
|
||||||
|
if table.Name == "filter" {
|
||||||
|
return table, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errFilterTableNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) createContainers() error {
|
||||||
|
|
||||||
|
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingFw,
|
||||||
|
Table: r.workTable,
|
||||||
|
})
|
||||||
|
|
||||||
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
|
|
||||||
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingNat,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
|
Priority: nftables.ChainPriorityNATSource - 1,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.acceptForwardRules()
|
||||||
|
|
||||||
|
err := r.refreshRulesMap()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||||
|
func (r *router) AddRouteFiltering(
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
|
return ruleKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chain := r.chains[chainNameRoutingFw]
|
||||||
|
var exprs []expr.Any
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||||
|
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||||
|
case len(sources) == 1:
|
||||||
|
// If there's only one source, we can use it directly
|
||||||
|
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
||||||
|
default:
|
||||||
|
// If there are multiple sources, create or get an ipset
|
||||||
|
var err error
|
||||||
|
exprs, err = r.getIpSetExprs(sources, exprs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle destination
|
||||||
|
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||||
|
|
||||||
|
// Handle protocol
|
||||||
|
if proto != firewall.ProtocolALL {
|
||||||
|
protoNum, err := protoToInt(proto)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
|
||||||
|
exprs = append(exprs, &expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
})
|
||||||
|
|
||||||
|
exprs = append(exprs, applyPort(sPort, true)...)
|
||||||
|
exprs = append(exprs, applyPort(dPort, false)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs, &expr.Counter{})
|
||||||
|
|
||||||
|
var verdict expr.VerdictKind
|
||||||
|
if action == firewall.ActionAccept {
|
||||||
|
verdict = expr.VerdictAccept
|
||||||
|
} else {
|
||||||
|
verdict = expr.VerdictDrop
|
||||||
|
}
|
||||||
|
exprs = append(exprs, &expr.Verdict{Kind: verdict})
|
||||||
|
|
||||||
|
rule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
|
||||||
|
|
||||||
|
return ruleKey, r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||||
|
setName := firewall.GenerateSetName(sources)
|
||||||
|
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 12,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ref.Out.Name,
|
||||||
|
SetID: ref.Out.ID,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.GetRuleID()
|
||||||
|
nftRule, exists := r.rules[ruleKey]
|
||||||
|
if !exists {
|
||||||
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
setName := r.findSetNameInRule(nftRule)
|
||||||
|
|
||||||
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
|
return fmt.Errorf("delete: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if setName != "" {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset reference: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||||
|
// overlapping prefixes will result in an error, so we need to merge them
|
||||||
|
sources = firewall.MergeIPRanges(sources)
|
||||||
|
|
||||||
|
set := &nftables.Set{
|
||||||
|
Name: setName,
|
||||||
|
Table: r.workTable,
|
||||||
|
// required for prefixes
|
||||||
|
Interval: true,
|
||||||
|
KeyType: nftables.TypeIPAddr,
|
||||||
|
}
|
||||||
|
|
||||||
|
var elements []nftables.SetElement
|
||||||
|
for _, prefix := range sources {
|
||||||
|
// TODO: Implement IPv6 support
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||||
|
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||||
|
firstIP := prefix.Addr()
|
||||||
|
lastIP := calculateLastIP(prefix).Next()
|
||||||
|
|
||||||
|
elements = append(elements,
|
||||||
|
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
|
||||||
|
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
|
||||||
|
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||||
|
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.AddSet(set, elements); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||||
|
|
||||||
|
return set, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateLastIP determines the last IP in a given prefix.
|
||||||
|
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||||
|
hostMask := ^uint32(0) >> prefix.Masked().Bits()
|
||||||
|
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
|
||||||
|
|
||||||
|
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility function to convert netip.Addr to uint32.
|
||||||
|
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||||
|
b := addr.As4()
|
||||||
|
return binary.BigEndian.Uint32(b[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||||
|
func uint32ToBytes(ip uint32) [4]byte {
|
||||||
|
var b [4]byte
|
||||||
|
binary.BigEndian.PutUint32(b[:], ip)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||||
|
r.conn.DelSet(set)
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Deleted unused ipset %s", setName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
||||||
|
for _, e := range rule.Exprs {
|
||||||
|
if lookup, ok := e.(*expr.Lookup); ok {
|
||||||
|
return lookup.SetName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
log.Debugf("removed route rule %s", ruleKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.legacyManagement {
|
||||||
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.Masquerade {
|
||||||
|
if err := r.addNatRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("add nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
|
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
|
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||||
|
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||||
|
|
||||||
|
dir := expr.MetaKeyIIFNAME
|
||||||
|
if pair.Inverse {
|
||||||
|
dir = expr.MetaKeyOIFNAME
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: dir,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs, sourceExp...)
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Counter{}, &expr.Masq{},
|
||||||
|
)
|
||||||
|
|
||||||
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleKey]; exists {
|
||||||
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("remove routing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
|
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||||
|
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||||
|
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||||
|
|
||||||
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
|
if _, exists := r.rules[ruleKey]; exists {
|
||||||
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Exprs: expression,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||||
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLegacyManagement returns the route manager's legacy management mode
|
||||||
|
func (r *router) GetLegacyManagement() bool {
|
||||||
|
return r.legacyManagement
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||||
|
func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||||
|
r.legacyManagement = isLegacy
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||||
|
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for k, rule := range r.rules {
|
||||||
|
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
|
||||||
|
// that our traffic is not dropped by existing rules there.
|
||||||
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
|
func (r *router) acceptForwardRules() {
|
||||||
|
if r.filterTable == nil {
|
||||||
|
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := ifname(r.wgIface.Name())
|
||||||
|
|
||||||
|
// Rule for incoming interface (iif) with counter
|
||||||
|
iifRule := &nftables.Rule{
|
||||||
|
Table: r.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: r.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||||
|
},
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||||
|
}
|
||||||
|
r.conn.InsertRule(iifRule)
|
||||||
|
|
||||||
|
// Rule for outgoing interface (oif) with counter
|
||||||
|
oifRule := &nftables.Rule{
|
||||||
|
Table: r.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: r.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: intf,
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: 2,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 2,
|
||||||
|
DestRegister: 2,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 2,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||||
|
},
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.InsertRule(oifRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||||
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
|
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
|
||||||
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
|
|
||||||
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
|
err := r.conn.DelRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("nftables: nat rule %s not found", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||||
|
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
|
func (r *router) refreshRulesMap() error {
|
||||||
|
for _, chain := range r.chains {
|
||||||
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||||
|
}
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 {
|
||||||
|
r.rules[string(rule.UserData)] = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||||
|
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||||
|
var offset uint32
|
||||||
|
if source {
|
||||||
|
offset = 12 // src offset
|
||||||
|
} else {
|
||||||
|
offset = 16 // dst offset
|
||||||
|
}
|
||||||
|
|
||||||
|
ones := prefix.Bits()
|
||||||
|
// 0.0.0.0/0 doesn't need extra expressions
|
||||||
|
if ones == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mask := net.CIDRMask(ones, 32)
|
||||||
|
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
// netmask
|
||||||
|
&expr.Bitwise{
|
||||||
|
DestRegister: 1,
|
||||||
|
SourceRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: mask,
|
||||||
|
Xor: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
// net address
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: prefix.Masked().Addr().AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||||
|
if port == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var exprs []expr.Any
|
||||||
|
|
||||||
|
offset := uint32(2) // Default offset for destination port
|
||||||
|
if isSource {
|
||||||
|
offset = 0 // Offset for source port
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs = append(exprs, &expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
|
// Handle port range
|
||||||
|
exprs = append(exprs,
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpGte,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpLte,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Handle single port or multiple ports
|
||||||
|
for i, p := range port.Values {
|
||||||
|
if i > 0 {
|
||||||
|
// Add a bitwise OR operation between port checks
|
||||||
|
exprs = append(exprs, &expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: []byte{0x00, 0x00, 0xff, 0xff},
|
||||||
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
exprs = append(exprs, &expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return exprs
|
||||||
|
}
|
@ -4,11 +4,15 @@ package nftables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@ -24,56 +28,50 @@ const (
|
|||||||
NFTABLES
|
NFTABLES
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this OS")
|
t.Skip("nftables not supported on this OS")
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := createWorkTable()
|
table, err := createWorkTable()
|
||||||
if err != nil {
|
require.NoError(t, err, "Failed to create work table")
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(context.TODO(), table)
|
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||||
require.NoError(t, err, "failed to create router")
|
require.NoError(t, err, "failed to create router")
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.ResetForwardRules()
|
defer func(manager *router) {
|
||||||
|
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||||
|
}(manager)
|
||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
err = manager.AddRoutingRules(testCase.InputPair)
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
defer func() {
|
require.NoError(t, err, "pair should be inserted")
|
||||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
|
||||||
}()
|
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
defer func(manager *router, pair firewall.RouterPair) {
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
|
||||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
}(manager, testCase.InputPair)
|
||||||
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
|
||||||
|
|
||||||
found := 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||||
|
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||||
|
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||||
|
testingExpression = append(testingExpression,
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(ifaceMock.Name()),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range manager.chains {
|
for _, chain := range manager.chains {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
|
||||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
|
||||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
|
||||||
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
|
|
||||||
found = 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||||
|
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||||
|
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||||
|
testingExpression = append(testingExpression,
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(ifaceMock.Name()),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range manager.chains {
|
for _, chain := range manager.chains {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this OS")
|
t.Skip("nftables not supported on this OS")
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := createWorkTable()
|
table, err := createWorkTable()
|
||||||
if err != nil {
|
require.NoError(t, err, "Failed to create work table")
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
for _, testCase := range test.RemoveRuleTestCases {
|
for _, testCase := range test.RemoveRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(context.TODO(), table)
|
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
||||||
require.NoError(t, err, "failed to create router")
|
require.NoError(t, err, "failed to create router")
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.ResetForwardRules()
|
defer func(manager *router) {
|
||||||
|
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||||
|
}(manager)
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||||
|
|
||||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
|
||||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRouteingFw],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(forwardRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
|
|
||||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
Table: manager.workTable,
|
Table: manager.workTable,
|
||||||
@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
UserData: []byte(natRuleKey),
|
UserData: []byte(natRuleKey),
|
||||||
})
|
})
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
|
||||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
|
||||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRouteingFw],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(inForwardRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||||
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||||
|
|
||||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
Table: manager.workTable,
|
Table: manager.workTable,
|
||||||
@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
err = nftablesTestingClient.Flush()
|
err = nftablesTestingClient.Flush()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
manager.ResetForwardRules()
|
err = manager.Reset()
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
err = manager.RemoveRoutingRules(testCase.InputPair)
|
err = manager.RemoveNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
for _, chain := range manager.chains {
|
for _, chain := range manager.chains {
|
||||||
@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
|
||||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
|
||||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err, "Failed to create work table")
|
||||||
|
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||||
|
require.NoError(t, err, "Failed to create router")
|
||||||
|
|
||||||
|
defer func(r *router) {
|
||||||
|
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||||
|
}(r)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sources []netip.Prefix
|
||||||
|
destination netip.Prefix
|
||||||
|
proto firewall.Protocol
|
||||||
|
sPort *firewall.Port
|
||||||
|
dPort *firewall.Port
|
||||||
|
direction firewall.RuleDirection
|
||||||
|
action firewall.Action
|
||||||
|
expectSet bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic TCP rule with single source",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
proto: firewall.ProtocolTCP,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: &firewall.Port{Values: []int{80}},
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP rule with multiple sources",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
proto: firewall.ProtocolUDP,
|
||||||
|
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionOUT,
|
||||||
|
action: firewall.ActionDrop,
|
||||||
|
expectSet: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "All protocols rule",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||||
|
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
proto: firewall.ProtocolALL,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP rule",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
proto: firewall.ProtocolICMP,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP rule with multiple source ports",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
proto: firewall.ProtocolTCP,
|
||||||
|
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionOUT,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP rule with single IP and port range",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||||
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
proto: firewall.ProtocolUDP,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionDrop,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP rule with source and destination ports",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
proto: firewall.ProtocolTCP,
|
||||||
|
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||||
|
dPort: &firewall.Port{Values: []int{22}},
|
||||||
|
direction: firewall.RuleDirectionOUT,
|
||||||
|
action: firewall.ActionAccept,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop all incoming traffic",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
|
proto: firewall.ProtocolALL,
|
||||||
|
sPort: nil,
|
||||||
|
dPort: nil,
|
||||||
|
direction: firewall.RuleDirectionIN,
|
||||||
|
action: firewall.ActionDrop,
|
||||||
|
expectSet: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
|
// Check if the rule is in the internal map
|
||||||
|
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||||
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
|
t.Log("Internal rule expressions:")
|
||||||
|
for i, expr := range rule.Exprs {
|
||||||
|
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify internal rule content
|
||||||
|
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||||
|
|
||||||
|
// Check if the rule exists in nftables and verify its content
|
||||||
|
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||||
|
require.NoError(t, err, "Failed to get rules from nftables")
|
||||||
|
|
||||||
|
var nftRule *nftables.Rule
|
||||||
|
for _, rule := range rules {
|
||||||
|
if string(rule.UserData) == ruleKey.GetRuleID() {
|
||||||
|
nftRule = rule
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, nftRule, "Rule not found in nftables")
|
||||||
|
t.Log("Actual nftables rule expressions:")
|
||||||
|
for i, expr := range nftRule.Exprs {
|
||||||
|
t.Logf(" [%d] %T: %+v", i, expr, expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify actual nftables rule content
|
||||||
|
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
err = r.DeleteRouteRule(ruleKey)
|
||||||
|
require.NoError(t, err, "Failed to delete rule")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNftablesCreateIpSet(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err, "Failed to create work table")
|
||||||
|
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
||||||
|
require.NoError(t, err, "Failed to create router")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sources []netip.Prefix
|
||||||
|
expected []netip.Prefix
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Single IP",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple IPs",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.1/32"),
|
||||||
|
netip.MustParsePrefix("10.0.0.1/32"),
|
||||||
|
netip.MustParsePrefix("172.16.0.1/32"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single Subnet",
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple Subnets with Various Prefix Lengths",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("203.0.113.0/26"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mix of Single IPs and Subnets in Different Positions",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.1/32"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/16"),
|
||||||
|
netip.MustParsePrefix("172.16.0.1/32"),
|
||||||
|
netip.MustParsePrefix("203.0.113.0/24"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Overlapping IPs/Subnets",
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/16"),
|
||||||
|
netip.MustParsePrefix("10.0.0.1/32"),
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.1/32"),
|
||||||
|
},
|
||||||
|
expected: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add this helper function inside TestNftablesCreateIpSet
|
||||||
|
printNftSets := func() {
|
||||||
|
cmd := exec.Command("nft", "list", "sets")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to run 'nft list sets': %v", err)
|
||||||
|
} else {
|
||||||
|
t.Logf("Current nft sets:\n%s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
setName := firewall.GenerateSetName(tt.sources)
|
||||||
|
set, err := r.createIpSet(setName, tt.sources)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to create IP set: %v", err)
|
||||||
|
printNftSets()
|
||||||
|
require.NoError(t, err, "Failed to create IP set")
|
||||||
|
}
|
||||||
|
require.NotNil(t, set, "Created set is nil")
|
||||||
|
|
||||||
|
// Verify set properties
|
||||||
|
assert.Equal(t, setName, set.Name, "Set name mismatch")
|
||||||
|
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
|
||||||
|
assert.True(t, set.Interval, "Set interval property should be true")
|
||||||
|
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
|
||||||
|
|
||||||
|
// Fetch the created set from nftables
|
||||||
|
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
|
||||||
|
require.NoError(t, err, "Failed to fetch created set")
|
||||||
|
require.NotNil(t, fetchedSet, "Fetched set is nil")
|
||||||
|
|
||||||
|
// Verify set elements
|
||||||
|
elements, err := r.conn.GetSetElements(fetchedSet)
|
||||||
|
require.NoError(t, err, "Failed to get set elements")
|
||||||
|
|
||||||
|
// Count the number of unique prefixes (excluding interval end markers)
|
||||||
|
uniquePrefixes := make(map[string]bool)
|
||||||
|
for _, elem := range elements {
|
||||||
|
if !elem.IntervalEnd {
|
||||||
|
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||||
|
uniquePrefixes[ip.String()] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check against expected merged prefixes
|
||||||
|
expectedCount := len(tt.expected)
|
||||||
|
if expectedCount == 0 {
|
||||||
|
expectedCount = len(tt.sources)
|
||||||
|
}
|
||||||
|
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
|
||||||
|
|
||||||
|
// Verify each expected prefix is in the set
|
||||||
|
for _, expected := range tt.expected {
|
||||||
|
found := false
|
||||||
|
for _, elem := range elements {
|
||||||
|
if !elem.IntervalEnd {
|
||||||
|
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
|
||||||
|
if expected.Contains(ip) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, found, "Expected prefix %s not found in set", expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.DelSet(set)
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
t.Logf("Failed to delete set: %v", err)
|
||||||
|
printNftSets()
|
||||||
|
}
|
||||||
|
require.NoError(t, err, "Failed to delete set")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
assert.NotNil(t, rule, "Rule should not be nil")
|
||||||
|
|
||||||
|
// Verify sources and destination
|
||||||
|
if expectSet {
|
||||||
|
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
|
||||||
|
} else if len(sources) == 1 && sources[0].Bits() != 0 {
|
||||||
|
if direction == firewall.RuleDirectionIN {
|
||||||
|
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
|
||||||
|
} else {
|
||||||
|
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if direction == firewall.RuleDirectionIN {
|
||||||
|
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
|
||||||
|
} else {
|
||||||
|
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify protocol
|
||||||
|
if proto != firewall.ProtocolALL {
|
||||||
|
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify ports
|
||||||
|
if sPort != nil {
|
||||||
|
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
|
||||||
|
}
|
||||||
|
if dPort != nil {
|
||||||
|
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify action
|
||||||
|
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsSetLookup(exprs []expr.Any) bool {
|
||||||
|
for _, e := range exprs {
|
||||||
|
if _, ok := e.(*expr.Lookup); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
|
||||||
|
var offset uint32
|
||||||
|
if isSource {
|
||||||
|
offset = 12 // src offset
|
||||||
|
} else {
|
||||||
|
offset = 16 // dst offset
|
||||||
|
}
|
||||||
|
|
||||||
|
var payloadFound, bitwiseFound, cmpFound bool
|
||||||
|
for _, e := range exprs {
|
||||||
|
switch ex := e.(type) {
|
||||||
|
case *expr.Payload:
|
||||||
|
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
|
||||||
|
payloadFound = true
|
||||||
|
}
|
||||||
|
case *expr.Bitwise:
|
||||||
|
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
|
||||||
|
bitwiseFound = true
|
||||||
|
}
|
||||||
|
case *expr.Cmp:
|
||||||
|
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
|
||||||
|
cmpFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||||
|
var offset uint32 = 2 // Default offset for destination port
|
||||||
|
if isSource {
|
||||||
|
offset = 0 // Offset for source port
|
||||||
|
}
|
||||||
|
|
||||||
|
var payloadFound, portMatchFound bool
|
||||||
|
for _, e := range exprs {
|
||||||
|
switch ex := e.(type) {
|
||||||
|
case *expr.Payload:
|
||||||
|
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||||
|
payloadFound = true
|
||||||
|
}
|
||||||
|
case *expr.Cmp:
|
||||||
|
if port.IsRange {
|
||||||
|
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
||||||
|
portMatchFound = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||||
|
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||||
|
for _, p := range port.Values {
|
||||||
|
if uint16(p) == portValue {
|
||||||
|
portMatchFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if payloadFound && portMatchFound {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||||
|
var metaFound, cmpFound bool
|
||||||
|
expectedProto, _ := protoToInt(proto)
|
||||||
|
for _, e := range exprs {
|
||||||
|
switch ex := e.(type) {
|
||||||
|
case *expr.Meta:
|
||||||
|
if ex.Key == expr.MetaKeyL4PROTO {
|
||||||
|
metaFound = true
|
||||||
|
}
|
||||||
|
case *expr.Cmp:
|
||||||
|
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
|
||||||
|
cmpFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return metaFound && cmpFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsAction(exprs []expr.Any, action firewall.Action) bool {
|
||||||
|
for _, e := range exprs {
|
||||||
|
if verdict, ok := e.(*expr.Verdict); ok {
|
||||||
|
switch action {
|
||||||
|
case firewall.ActionAccept:
|
||||||
|
return verdict.Kind == expr.VerdictAccept
|
||||||
|
case firewall.ActionDrop:
|
||||||
|
return verdict.Kind == expr.VerdictDrop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||||
func check() int {
|
func check() int {
|
||||||
nf := nftables.Conn{}
|
nf := nftables.Conn{}
|
||||||
@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == tableName {
|
if t.Name == tableNameNetbird {
|
||||||
sConn.DelTable(t)
|
sConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||||
err = sConn.Flush()
|
err = sConn.Flush()
|
||||||
|
|
||||||
return table, err
|
return table, err
|
||||||
@ -273,7 +708,7 @@ func deleteWorkTable() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == tableName {
|
if t.Name == tableNameNetbird {
|
||||||
sConn.DelTable(t)
|
sConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package test
|
package test
|
||||||
|
|
||||||
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
InsertRuleTestCases = []struct {
|
InsertRuleTestCases = []struct {
|
||||||
@ -13,8 +15,8 @@ var (
|
|||||||
Name: "Insert Forwarding IPV4 Rule",
|
Name: "Insert Forwarding IPV4 Rule",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: "100.100.100.1/32",
|
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||||
Destination: "100.100.200.0/24",
|
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -22,8 +24,8 @@ var (
|
|||||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: "100.100.100.1/32",
|
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||||
Destination: "100.100.200.0/24",
|
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -38,8 +40,8 @@ var (
|
|||||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: "100.100.100.1/32",
|
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||||
Destination: "100.100.200.0/24",
|
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@ -103,26 +104,26 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return errRouteNotSupported
|
return errRouteNotSupported
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.InsertRoutingRules(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveRoutingRules removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return errRouteNotSupported
|
return errRouteNotSupported
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.RemoveRoutingRules(pair)
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
@ -188,8 +189,22 @@ func (m *Manager) AddFiltering(
|
|||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
|
||||||
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errRouteNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errRouteNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@ -215,6 +230,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) SetLegacyManagement(_ bool) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
@ -395,7 +415,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
|||||||
for _, r := range arr {
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
rule := r
|
rule := r
|
||||||
return m.DeleteRule(&rule)
|
return m.DeletePeerRule(&rule)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -403,7 +423,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
|||||||
for _, r := range arr {
|
for _, r := range arr {
|
||||||
if r.id == hookID {
|
if r.id == hookID {
|
||||||
rule := r
|
rule := r
|
||||||
return m.DeleteRule(&rule)
|
return m.DeletePeerRule(&rule)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManagerAddFiltering(t *testing.T) {
|
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||||
isSetFilterCalled := false
|
isSetFilterCalled := false
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(iface.PacketFilter) error {
|
SetFilterFunc: func(iface.PacketFilter) error {
|
||||||
@ -71,7 +71,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -106,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -119,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action = fw.ActionDrop
|
action = fw.ActionDrop
|
||||||
comment = "Test rule 2"
|
comment = "Test rule 2"
|
||||||
|
|
||||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule {
|
for _, r := range rule {
|
||||||
err = m.DeleteRule(r)
|
err = m.DeletePeerRule(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
return
|
return
|
||||||
@ -140,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err = m.DeleteRule(r)
|
err = m.DeletePeerRule(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
return
|
return
|
||||||
@ -252,7 +252,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -290,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -406,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
25
client/internal/acl/id/id.go
Normal file
25
client/internal/acl/id/id.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package id
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RuleID string
|
||||||
|
|
||||||
|
func (r RuleID) GetRuleID() string {
|
||||||
|
return string(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateRouteRuleKey(
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto manager.Protocol,
|
||||||
|
sPort *manager.Port,
|
||||||
|
dPort *manager.Port,
|
||||||
|
action manager.Action,
|
||||||
|
) RuleID {
|
||||||
|
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
|
||||||
|
}
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -12,6 +13,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
@ -23,16 +25,18 @@ type Manager interface {
|
|||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
ipsetCounter int
|
ipsetCounter int
|
||||||
rulesPairs map[string][]firewall.Rule
|
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||||
mutex sync.Mutex
|
routeRules map[id.RuleID]struct{}
|
||||||
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
return &DefaultManager{
|
return &DefaultManager{
|
||||||
firewall: fm,
|
firewall: fm,
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||||
|
routeRules: make(map[id.RuleID]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
total := 0
|
total := 0
|
||||||
for _, pairs := range d.rulesPairs {
|
for _, pairs := range d.peerRulesPairs {
|
||||||
total += len(pairs)
|
total += len(pairs)
|
||||||
}
|
}
|
||||||
log.Infof(
|
log.Infof(
|
||||||
@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
d.applyPeerACLs(networkMap)
|
||||||
if err := d.firewall.Flush(); err != nil {
|
|
||||||
log.Error("failed to flush firewall rules: ", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
|
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||||
|
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
||||||
|
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||||
|
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||||
|
log.Errorf("failed to set legacy management flag: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
|
||||||
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.firewall.Flush(); err != nil {
|
||||||
|
log.Error("failed to flush firewall rules: ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
enableSSH := (networkMap.PeerConfig != nil &&
|
enableSSH := networkMap.PeerConfig != nil &&
|
||||||
networkMap.PeerConfig.SshConfig != nil &&
|
networkMap.PeerConfig.SshConfig != nil &&
|
||||||
networkMap.PeerConfig.SshConfig.SshEnabled)
|
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||||
enableSSH = enableSSH && !ok
|
enableSSH = enableSSH && !ok
|
||||||
}
|
}
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
|
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
||||||
enableSSH = enableSSH && !ok
|
enableSSH = enableSSH && !ok
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
if enableSSH {
|
if enableSSH {
|
||||||
rules = append(rules, &mgmProto.FirewallRule{
|
rules = append(rules, &mgmProto.FirewallRule{
|
||||||
PeerIP: "0.0.0.0",
|
PeerIP: "0.0.0.0",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
rules = append(rules,
|
rules = append(rules,
|
||||||
&mgmProto.FirewallRule{
|
&mgmProto.FirewallRule{
|
||||||
PeerIP: "0.0.0.0",
|
PeerIP: "0.0.0.0",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
&mgmProto.FirewallRule{
|
&mgmProto.FirewallRule{
|
||||||
PeerIP: "0.0.0.0",
|
PeerIP: "0.0.0.0",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||||
ipsetByRuleSelectors := make(map[string]string)
|
ipsetByRuleSelectors := make(map[string]string)
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
if len(rules) > 0 {
|
if len(rules) > 0 {
|
||||||
d.rulesPairs[pairID] = rulePair
|
d.peerRulesPairs[pairID] = rulePair
|
||||||
newRulePairs[pairID] = rulePair
|
newRulePairs[pairID] = rulePair
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for pairID, rules := range d.rulesPairs {
|
for pairID, rules := range d.peerRulesPairs {
|
||||||
if _, ok := newRulePairs[pairID]; !ok {
|
if _, ok := newRulePairs[pairID]; !ok {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||||
log.Errorf("failed to delete firewall rule: %v", err)
|
log.Errorf("failed to delete peer firewall rule: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(d.rulesPairs, pairID)
|
delete(d.peerRulesPairs, pairID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
d.rulesPairs = newRulePairs
|
d.peerRulesPairs = newRulePairs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||||
|
var newRouteRules = make(map[id.RuleID]struct{})
|
||||||
|
for _, rule := range rules {
|
||||||
|
id, err := d.applyRouteACL(rule)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply route ACL: %w", err)
|
||||||
|
}
|
||||||
|
newRouteRules[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id := range d.routeRules {
|
||||||
|
if _, ok := newRouteRules[id]; !ok {
|
||||||
|
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||||
|
log.Errorf("failed to delete route firewall rule: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(d.routeRules, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.routeRules = newRouteRules
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||||
|
if len(rule.SourceRanges) == 0 {
|
||||||
|
return "", fmt.Errorf("source ranges is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var sources []netip.Prefix
|
||||||
|
for _, sourceRange := range rule.SourceRanges {
|
||||||
|
source, err := netip.ParsePrefix(sourceRange)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse source range: %w", err)
|
||||||
|
}
|
||||||
|
sources = append(sources, source)
|
||||||
|
}
|
||||||
|
|
||||||
|
var destination netip.Prefix
|
||||||
|
if rule.IsDynamic {
|
||||||
|
destination = getDefault(sources[0])
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
destination, err = netip.ParsePrefix(rule.Destination)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse destination: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid protocol: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
action, err := convertFirewallAction(rule.Action)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid action: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dPorts := convertPortInfo(rule.PortInfo)
|
||||||
|
|
||||||
|
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("add route rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return id.RuleID(addedRule.GetRuleID()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
r *mgmProto.FirewallRule,
|
r *mgmProto.FirewallRule,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) (string, []firewall.Rule, error) {
|
) (id.RuleID, []firewall.Rule, error) {
|
||||||
ip := net.ParseIP(r.PeerIP)
|
ip := net.ParseIP(r.PeerIP)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||||
return ruleID, rulesPair, nil
|
return ruleID, rulesPair, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.FirewallRule_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
case mgmProto.FirewallRule_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules(
|
|||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.firewall.AddFiltering(
|
rule, err := d.firewall.AddPeerFiltering(
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules(
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.firewall.AddFiltering(
|
rule, err = d.firewall.AddPeerFiltering(
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules(
|
|||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.firewall.AddFiltering(
|
rule, err := d.firewall.AddPeerFiltering(
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules(
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.firewall.AddFiltering(
|
rule, err = d.firewall.AddPeerFiltering(
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules(
|
|||||||
return append(rules, rule...), nil
|
return append(rules, rule...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||||
func (d *DefaultManager) getRuleID(
|
func (d *DefaultManager) getPeerRuleID(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
direction int,
|
direction int,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
comment string,
|
comment string,
|
||||||
) string {
|
) id.RuleID {
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||||
if port != nil {
|
if port != nil {
|
||||||
idStr += port.String()
|
idStr += port.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||||
}
|
}
|
||||||
|
|
||||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||||
@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID(
|
|||||||
// but other has port definitions or has drop policy.
|
// but other has port definitions or has drop policy.
|
||||||
func (d *DefaultManager) squashAcceptRules(
|
func (d *DefaultManager) squashAcceptRules(
|
||||||
networkMap *mgmProto.NetworkMap,
|
networkMap *mgmProto.NetworkMap,
|
||||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
||||||
totalIPs := 0
|
totalIPs := 0
|
||||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||||
for range p.AllowedIps {
|
for range p.AllowedIps {
|
||||||
@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
|
type protoMatch map[mgmProto.RuleProtocol]map[string]int
|
||||||
|
|
||||||
in := protoMatch{}
|
in := protoMatch{}
|
||||||
out := protoMatch{}
|
out := protoMatch{}
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
// trace which type of protocols was squashed
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
squashedRules := []*mgmProto.FirewallRule{}
|
||||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||||
@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
//
|
//
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
||||||
drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
|
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||||
if drop {
|
if drop {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = map[string]int{}
|
||||||
return
|
return
|
||||||
@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
for i, r := range networkMap.FirewallRules {
|
||||||
// calculate squash for different directions
|
// calculate squash for different directions
|
||||||
if r.Direction == mgmProto.FirewallRule_IN {
|
if r.Direction == mgmProto.RuleDirection_IN {
|
||||||
addRuleToCalculationMap(i, r, in)
|
addRuleToCalculationMap(i, r, in)
|
||||||
} else {
|
} else {
|
||||||
addRuleToCalculationMap(i, r, out)
|
addRuleToCalculationMap(i, r, out)
|
||||||
@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
|
|
||||||
// order of squashing by protocol is important
|
// order of squashing by protocol is important
|
||||||
// only for their first element ALL, it must be done first
|
// only for their first element ALL, it must be done first
|
||||||
protocolOrders := []mgmProto.FirewallRuleProtocol{
|
protocolOrders := []mgmProto.RuleProtocol{
|
||||||
mgmProto.FirewallRule_ALL,
|
mgmProto.RuleProtocol_ALL,
|
||||||
mgmProto.FirewallRule_ICMP,
|
mgmProto.RuleProtocol_ICMP,
|
||||||
mgmProto.FirewallRule_TCP,
|
mgmProto.RuleProtocol_TCP,
|
||||||
mgmProto.FirewallRule_UDP,
|
mgmProto.RuleProtocol_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
|
||||||
for _, protocol := range protocolOrders {
|
for _, protocol := range protocolOrders {
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||||
// don't squash if :
|
// don't squash if :
|
||||||
@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
||||||
PeerIP: "0.0.0.0",
|
PeerIP: "0.0.0.0",
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: protocol,
|
Protocol: protocol,
|
||||||
})
|
})
|
||||||
squashedProtocols[protocol] = struct{}{}
|
squashedProtocols[protocol] = struct{}{}
|
||||||
|
|
||||||
if protocol == mgmProto.FirewallRule_ALL {
|
if protocol == mgmProto.RuleProtocol_ALL {
|
||||||
// if we have ALL traffic type squashed rule
|
// if we have ALL traffic type squashed rule
|
||||||
// it allows all other type of traffic, so we can stop processing
|
// it allows all other type of traffic, so we can stop processing
|
||||||
break
|
break
|
||||||
@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
squash(in, mgmProto.FirewallRule_IN)
|
squash(in, mgmProto.RuleDirection_IN)
|
||||||
squash(out, mgmProto.FirewallRule_OUT)
|
squash(out, mgmProto.RuleDirection_OUT)
|
||||||
|
|
||||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
// if all protocol was squashed everything is allow and we can ignore all other rules
|
||||||
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
|
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||||
return squashedRules, squashedProtocols
|
return squashedRules, squashedProtocols
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
|||||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
|
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||||
log.Debugf("rollback ACL to previous state")
|
log.Debugf("rollback ACL to previous state")
|
||||||
for _, rules := range newRulePairs {
|
for _, rules := range newRulePairs {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := d.firewall.DeleteRule(rule); err != nil {
|
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
|
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.FirewallRule_TCP:
|
case mgmProto.RuleProtocol_TCP:
|
||||||
return firewall.ProtocolTCP, nil
|
return firewall.ProtocolTCP, nil
|
||||||
case mgmProto.FirewallRule_UDP:
|
case mgmProto.RuleProtocol_UDP:
|
||||||
return firewall.ProtocolUDP, nil
|
return firewall.ProtocolUDP, nil
|
||||||
case mgmProto.FirewallRule_ICMP:
|
case mgmProto.RuleProtocol_ICMP:
|
||||||
return firewall.ProtocolICMP, nil
|
return firewall.ProtocolICMP, nil
|
||||||
case mgmProto.FirewallRule_ALL:
|
case mgmProto.RuleProtocol_ALL:
|
||||||
return firewall.ProtocolALL, nil
|
return firewall.ProtocolALL, nil
|
||||||
default:
|
default:
|
||||||
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||||
@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
|
|||||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
|
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
|
||||||
switch action {
|
switch action {
|
||||||
case mgmProto.FirewallRule_ACCEPT:
|
case mgmProto.RuleAction_ACCEPT:
|
||||||
return firewall.ActionAccept, nil
|
return firewall.ActionAccept, nil
|
||||||
case mgmProto.FirewallRule_DROP:
|
case mgmProto.RuleAction_DROP:
|
||||||
return firewall.ActionDrop, nil
|
return firewall.ActionDrop, nil
|
||||||
default:
|
default:
|
||||||
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
||||||
|
if portInfo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInfo.GetPort() != 0 {
|
||||||
|
return &firewall.Port{
|
||||||
|
Values: []int{int(portInfo.GetPort())},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInfo.GetRange() != nil {
|
||||||
|
return &firewall.Port{
|
||||||
|
IsRange: true,
|
||||||
|
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDefault(prefix netip.Prefix) netip.Prefix {
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
}
|
||||||
|
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
|
}
|
||||||
|
@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.1",
|
PeerIP: "10.93.0.1",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
Port: "80",
|
Port: "80",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.2",
|
PeerIP: "10.93.0.2",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_DROP,
|
Action: mgmProto.RuleAction_DROP,
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
Port: "53",
|
Port: "53",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
t.Run("apply firewall rules", func(t *testing.T) {
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
|
|
||||||
if len(acl.rulesPairs) != 2 {
|
if len(acl.peerRulesPairs) != 2 {
|
||||||
t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
|
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add extra rules", func(t *testing.T) {
|
t.Run("add extra rules", func(t *testing.T) {
|
||||||
existedPairs := map[string]struct{}{}
|
existedPairs := map[string]struct{}{}
|
||||||
for id := range acl.rulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
existedPairs[id] = struct{}{}
|
existedPairs[id.GetRuleID()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove first rule
|
// remove first rule
|
||||||
@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
networkMap.FirewallRules,
|
networkMap.FirewallRules,
|
||||||
&mgmProto.FirewallRule{
|
&mgmProto.FirewallRule{
|
||||||
PeerIP: "10.93.0.3",
|
PeerIP: "10.93.0.3",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_DROP,
|
Action: mgmProto.RuleAction_DROP,
|
||||||
Protocol: mgmProto.FirewallRule_ICMP,
|
Protocol: mgmProto.RuleProtocol_ICMP,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
// we should have one old and one new rule in the existed rules
|
||||||
if len(acl.rulesPairs) != 2 {
|
if len(acl.peerRulesPairs) != 2 {
|
||||||
t.Errorf("firewall rules not applied")
|
t.Errorf("firewall rules not applied")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.rulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
if _, ok := existedPairs[id]; ok {
|
if _, ok := existedPairs[id.GetRuleID()]; ok {
|
||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
networkMap.FirewallRulesIsEmpty = true
|
||||||
if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
|
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
|
||||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
|
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
if len(acl.rulesPairs) != 2 {
|
if len(acl.peerRulesPairs) != 2 {
|
||||||
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
|
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
|||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.1",
|
PeerIP: "10.93.0.1",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.2",
|
PeerIP: "10.93.0.2",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.3",
|
PeerIP: "10.93.0.3",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.4",
|
PeerIP: "10.93.0.4",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.1",
|
PeerIP: "10.93.0.1",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.2",
|
PeerIP: "10.93.0.2",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.3",
|
PeerIP: "10.93.0.3",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.4",
|
PeerIP: "10.93.0.4",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
|||||||
case r.PeerIP != "0.0.0.0":
|
case r.PeerIP != "0.0.0.0":
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||||
return
|
return
|
||||||
case r.Direction != mgmProto.FirewallRule_IN:
|
case r.Direction != mgmProto.RuleDirection_IN:
|
||||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
t.Errorf("direction should be IN, got: %v", r.Direction)
|
||||||
return
|
return
|
||||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||||
return
|
return
|
||||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
|||||||
case r.PeerIP != "0.0.0.0":
|
case r.PeerIP != "0.0.0.0":
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
||||||
return
|
return
|
||||||
case r.Direction != mgmProto.FirewallRule_OUT:
|
case r.Direction != mgmProto.RuleDirection_OUT:
|
||||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
||||||
return
|
return
|
||||||
case r.Protocol != mgmProto.FirewallRule_ALL:
|
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
||||||
return
|
return
|
||||||
case r.Action != mgmProto.FirewallRule_ACCEPT:
|
case r.Action != mgmProto.RuleAction_ACCEPT:
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.1",
|
PeerIP: "10.93.0.1",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.2",
|
PeerIP: "10.93.0.2",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.3",
|
PeerIP: "10.93.0.3",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.4",
|
PeerIP: "10.93.0.4",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.1",
|
PeerIP: "10.93.0.1",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.2",
|
PeerIP: "10.93.0.2",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.3",
|
PeerIP: "10.93.0.3",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_ALL,
|
Protocol: mgmProto.RuleProtocol_ALL,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.4",
|
PeerIP: "10.93.0.4",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.1",
|
PeerIP: "10.93.0.1",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.2",
|
PeerIP: "10.93.0.2",
|
||||||
Direction: mgmProto.FirewallRule_IN,
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_TCP,
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "10.93.0.3",
|
PeerIP: "10.93.0.3",
|
||||||
Direction: mgmProto.FirewallRule_OUT,
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
Action: mgmProto.FirewallRule_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: mgmProto.FirewallRule_UDP,
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
|
|
||||||
if len(acl.rulesPairs) != 4 {
|
if len(acl.peerRulesPairs) != 4 {
|
||||||
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
|
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -704,6 +704,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply ACLs in the beginning to avoid security leaks
|
||||||
|
if e.acl != nil {
|
||||||
|
e.acl.ApplyFiltering(networkMap)
|
||||||
|
}
|
||||||
|
|
||||||
protoRoutes := networkMap.GetRoutes()
|
protoRoutes := networkMap.GetRoutes()
|
||||||
if protoRoutes == nil {
|
if protoRoutes == nil {
|
||||||
protoRoutes = []*mgmProto.Route{}
|
protoRoutes = []*mgmProto.Route{}
|
||||||
@ -770,10 +775,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.acl != nil {
|
|
||||||
e.acl.ApplyFiltering(networkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
|
|
||||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||||
|
@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
|
if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -87,10 +87,10 @@ func NewManager(
|
|||||||
}
|
}
|
||||||
|
|
||||||
dm.routeRefCounter = refcounter.New(
|
dm.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ any) (any, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ any) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,8 @@ package refcounter
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
@ -12,118 +13,153 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
|
const logLevel = log.TraceLevel
|
||||||
|
|
||||||
|
// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key.
|
||||||
var ErrIgnore = errors.New("ignore")
|
var ErrIgnore = errors.New("ignore")
|
||||||
|
|
||||||
|
// Ref holds the reference count and associated data for a key.
|
||||||
type Ref[O any] struct {
|
type Ref[O any] struct {
|
||||||
Count int
|
Count int
|
||||||
Out O
|
Out O
|
||||||
}
|
}
|
||||||
|
|
||||||
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
|
// AddFunc is the function type for adding a new key.
|
||||||
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
|
// Key is the type of the key (e.g., netip.Prefix).
|
||||||
|
type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error)
|
||||||
|
|
||||||
type Counter[I, O any] struct {
|
// RemoveFunc is the function type for removing a key.
|
||||||
// refCountMap keeps track of the reference Ref for prefixes
|
type RemoveFunc[Key, O any] func(key Key, out O) error
|
||||||
refCountMap map[netip.Prefix]Ref[O]
|
|
||||||
|
// Counter is a generic reference counter for managing keys and their associated data.
|
||||||
|
// Key: The type of the key (e.g., netip.Prefix, string).
|
||||||
|
//
|
||||||
|
// I: The input type for the AddFunc. It is the input type for additional data needed
|
||||||
|
// when adding a key, it is passed as the second argument to AddFunc.
|
||||||
|
//
|
||||||
|
// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc.
|
||||||
|
// It is stored and passed to RemoveFunc when the reference count reaches 0.
|
||||||
|
//
|
||||||
|
// The types can be aliased to a specific type using the following syntax:
|
||||||
|
//
|
||||||
|
// type RouteRefCounter = Counter[netip.Prefix, any, any]
|
||||||
|
type Counter[Key comparable, I, O any] struct {
|
||||||
|
// refCountMap keeps track of the reference Ref for keys
|
||||||
|
refCountMap map[Key]Ref[O]
|
||||||
refCountMu sync.Mutex
|
refCountMu sync.Mutex
|
||||||
// idMap keeps track of the prefixes associated with an ID for removal
|
// idMap keeps track of the keys associated with an ID for removal
|
||||||
idMap map[string][]netip.Prefix
|
idMap map[string][]Key
|
||||||
idMu sync.Mutex
|
idMu sync.Mutex
|
||||||
add AddFunc[I, O]
|
add AddFunc[Key, I, O]
|
||||||
remove RemoveFunc[I, O]
|
remove RemoveFunc[Key, O]
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Counter instance
|
// New creates a new Counter instance.
|
||||||
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
|
// Usage example:
|
||||||
return &Counter[I, O]{
|
//
|
||||||
refCountMap: map[netip.Prefix]Ref[O]{},
|
// counter := New[netip.Prefix, string, string](
|
||||||
idMap: map[string][]netip.Prefix{},
|
// func(key netip.Prefix, in string) (out string, err error) { ... },
|
||||||
|
// func(key netip.Prefix, out string) error { ... },`
|
||||||
|
// )
|
||||||
|
func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] {
|
||||||
|
return &Counter[Key, I, O]{
|
||||||
|
refCountMap: map[Key]Ref[O]{},
|
||||||
|
idMap: map[string][]Key{},
|
||||||
add: add,
|
add: add,
|
||||||
remove: remove,
|
remove: remove,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment increments the reference count for the given prefix.
|
// Get retrieves the current reference count and associated data for a key.
|
||||||
// If this is the first reference to the prefix, the AddFunc is called.
|
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||||
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
|
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||||
rm.refCountMu.Lock()
|
rm.refCountMu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.refCountMu.Unlock()
|
||||||
|
|
||||||
ref := rm.refCountMap[prefix]
|
ref, ok := rm.refCountMap[key]
|
||||||
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
return ref, ok
|
||||||
|
}
|
||||||
|
|
||||||
// Call AddFunc only if it's a new prefix
|
// Increment increments the reference count for the given key.
|
||||||
|
// If this is the first reference to the key, the AddFunc is called.
|
||||||
|
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||||
|
rm.refCountMu.Lock()
|
||||||
|
defer rm.refCountMu.Unlock()
|
||||||
|
|
||||||
|
ref := rm.refCountMap[key]
|
||||||
|
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
||||||
|
|
||||||
|
// Call AddFunc only if it's a new key
|
||||||
if ref.Count == 0 {
|
if ref.Count == 0 {
|
||||||
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
|
logCallerF("Calling add for key %v", key)
|
||||||
out, err := rm.add(prefix, in)
|
out, err := rm.add(key, in)
|
||||||
|
|
||||||
if errors.Is(err, ErrIgnore) {
|
if errors.Is(err, ErrIgnore) {
|
||||||
return ref, nil
|
return ref, nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
|
return ref, fmt.Errorf("failed to add for key %v: %w", key, err)
|
||||||
}
|
}
|
||||||
ref.Out = out
|
ref.Out = out
|
||||||
}
|
}
|
||||||
|
|
||||||
ref.Count++
|
ref.Count++
|
||||||
rm.refCountMap[prefix] = ref
|
rm.refCountMap[key] = ref
|
||||||
|
|
||||||
return ref, nil
|
return ref, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
|
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
||||||
// If this is the first reference to the prefix, the AddFunc is called.
|
// If this is the first reference to the key, the AddFunc is called.
|
||||||
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
|
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
||||||
rm.idMu.Lock()
|
rm.idMu.Lock()
|
||||||
defer rm.idMu.Unlock()
|
defer rm.idMu.Unlock()
|
||||||
|
|
||||||
ref, err := rm.Increment(prefix, in)
|
ref, err := rm.Increment(key, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ref, fmt.Errorf("with ID: %w", err)
|
return ref, fmt.Errorf("with ID: %w", err)
|
||||||
}
|
}
|
||||||
rm.idMap[id] = append(rm.idMap[id], prefix)
|
rm.idMap[id] = append(rm.idMap[id], key)
|
||||||
|
|
||||||
return ref, nil
|
return ref, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrement decrements the reference count for the given prefix.
|
// Decrement decrements the reference count for the given key.
|
||||||
// If the reference count reaches 0, the RemoveFunc is called.
|
// If the reference count reaches 0, the RemoveFunc is called.
|
||||||
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
|
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||||
rm.refCountMu.Lock()
|
rm.refCountMu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.refCountMu.Unlock()
|
||||||
|
|
||||||
ref, ok := rm.refCountMap[prefix]
|
ref, ok := rm.refCountMap[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Tracef("No reference found for prefix %s", prefix)
|
logCallerF("No reference found for key %v", key)
|
||||||
return ref, nil
|
return ref, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out)
|
||||||
if ref.Count == 1 {
|
if ref.Count == 1 {
|
||||||
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
|
logCallerF("Calling remove for key %v", key)
|
||||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
if err := rm.remove(key, ref.Out); err != nil {
|
||||||
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
|
return ref, fmt.Errorf("remove for key %v: %w", key, err)
|
||||||
}
|
}
|
||||||
delete(rm.refCountMap, prefix)
|
delete(rm.refCountMap, key)
|
||||||
} else {
|
} else {
|
||||||
ref.Count--
|
ref.Count--
|
||||||
rm.refCountMap[prefix] = ref
|
rm.refCountMap[key] = ref
|
||||||
}
|
}
|
||||||
|
|
||||||
return ref, nil
|
return ref, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
|
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
||||||
// If the reference count reaches 0, the RemoveFunc is called.
|
// If the reference count reaches 0, the RemoveFunc is called.
|
||||||
func (rm *Counter[I, O]) DecrementWithID(id string) error {
|
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||||
rm.idMu.Lock()
|
rm.idMu.Lock()
|
||||||
defer rm.idMu.Unlock()
|
defer rm.idMu.Unlock()
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, prefix := range rm.idMap[id] {
|
for _, key := range rm.idMap[id] {
|
||||||
if _, err := rm.Decrement(prefix); err != nil {
|
if _, err := rm.Decrement(key); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush removes all references and calls RemoveFunc for each prefix.
|
// Flush removes all references and calls RemoveFunc for each key.
|
||||||
func (rm *Counter[I, O]) Flush() error {
|
func (rm *Counter[Key, I, O]) Flush() error {
|
||||||
rm.refCountMu.Lock()
|
rm.refCountMu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.refCountMu.Unlock()
|
||||||
rm.idMu.Lock()
|
rm.idMu.Lock()
|
||||||
defer rm.idMu.Unlock()
|
defer rm.idMu.Unlock()
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for prefix := range rm.refCountMap {
|
for key := range rm.refCountMap {
|
||||||
log.Tracef("Removing for prefix %s", prefix)
|
logCallerF("Calling remove for key %v", key)
|
||||||
ref := rm.refCountMap[prefix]
|
ref := rm.refCountMap[key]
|
||||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
if err := rm.remove(key, ref.Out); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rm.refCountMap = map[netip.Prefix]Ref[O]{}
|
|
||||||
|
|
||||||
rm.idMap = map[string][]netip.Prefix{}
|
clear(rm.refCountMap)
|
||||||
|
clear(rm.idMap)
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear removes all references without calling RemoveFunc.
|
||||||
|
func (rm *Counter[Key, I, O]) Clear() {
|
||||||
|
rm.refCountMu.Lock()
|
||||||
|
defer rm.refCountMu.Unlock()
|
||||||
|
rm.idMu.Lock()
|
||||||
|
defer rm.idMu.Unlock()
|
||||||
|
|
||||||
|
clear(rm.refCountMap)
|
||||||
|
clear(rm.idMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCallerInfo(depth int, maxDepth int) (string, bool) {
|
||||||
|
if depth >= maxDepth {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, _, _, ok := runtime.Caller(depth)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if details := runtime.FuncForPC(pc); details != nil {
|
||||||
|
name := details.Name()
|
||||||
|
|
||||||
|
lastDotIndex := strings.LastIndex(name, "/")
|
||||||
|
if lastDotIndex != -1 {
|
||||||
|
name = name[lastDotIndex+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(name, "refcounter.") {
|
||||||
|
// +2 to account for recursion
|
||||||
|
return getCallerInfo(depth+2, maxDepth)
|
||||||
|
}
|
||||||
|
|
||||||
|
return name, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// logCaller logs a message with the package name and method of the function that called the current function.
|
||||||
|
func logCallerF(format string, args ...interface{}) {
|
||||||
|
if log.GetLevel() < logLevel {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if callerName, ok := getCallerInfo(3, 18); ok {
|
||||||
|
format = fmt.Sprintf("[%s] %s", callerName, format)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.StandardLogger().Logf(logLevel, format, args...)
|
||||||
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package refcounter
|
package refcounter
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
|
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
|
||||||
type RouteRefCounter = Counter[any, any]
|
type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}]
|
||||||
|
|
||||||
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
|
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
|
||||||
type AllowedIPsRefCounter = Counter[string, string]
|
type AllowedIPsRefCounter = Counter[netip.Prefix, string, string]
|
||||||
|
@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
|||||||
return fmt.Errorf("parse prefix: %w", err)
|
return fmt.Errorf("parse prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
err = m.firewall.RemoveNatRule(routerPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("remove routing rules: %w", err)
|
return fmt.Errorf("remove routing rules: %w", err)
|
||||||
}
|
}
|
||||||
@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
|||||||
return fmt.Errorf("parse prefix: %w", err)
|
return fmt.Errorf("parse prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.firewall.InsertRoutingRules(routerPair)
|
err = m.firewall.AddNatRule(routerPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("insert routing rules: %w", err)
|
return fmt.Errorf("insert routing rules: %w", err)
|
||||||
}
|
}
|
||||||
@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
err = m.firewall.RemoveNatRule(routerPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
log.Errorf("Failed to remove cleanup route: %v", err)
|
||||||
}
|
}
|
||||||
@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
|||||||
// TODO: add ipv6
|
// TODO: add ipv6
|
||||||
source := getDefaultPrefix(route.Network)
|
source := getDefaultPrefix(route.Network)
|
||||||
|
|
||||||
destination := route.Network.Masked().String()
|
destination := route.Network.Masked()
|
||||||
if route.IsDynamic() {
|
if route.IsDynamic() {
|
||||||
// TODO: add ipv6
|
// TODO: add ipv6 additionally
|
||||||
destination = "0.0.0.0/0"
|
destination = getDefaultPrefix(destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
return firewall.RouterPair{
|
return firewall.RouterPair{
|
||||||
ID: string(route.ID),
|
ID: route.ID,
|
||||||
Source: source.String(),
|
Source: source,
|
||||||
Destination: destination,
|
Destination: destination,
|
||||||
Masquerade: route.Masquerade,
|
Masquerade: route.Masquerade,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -30,7 +30,7 @@ func (r *Route) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) AddRoute(context.Context) error {
|
func (r *Route) AddRoute(context.Context) error {
|
||||||
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
|
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ type Nexthop struct {
|
|||||||
Intf *net.Interface
|
Intf *net.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
type ExclusionCounter = refcounter.Counter[any, Nexthop]
|
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||||
|
|
||||||
type SysOps struct {
|
type SysOps struct {
|
||||||
refCounter *ExclusionCounter
|
refCounter *ExclusionCounter
|
||||||
|
@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
|
|||||||
}
|
}
|
||||||
|
|
||||||
refCounter := refcounter.New(
|
refCounter := refcounter.New(
|
||||||
func(prefix netip.Prefix, _ any) (Nexthop, error) {
|
func(prefix netip.Prefix, _ struct{}) (Nexthop, error) {
|
||||||
initialNexthop := initialNextHopV4
|
initialNexthop := initialNextHopV4
|
||||||
if prefix.Addr().Is6() {
|
if prefix.Addr().Is6() {
|
||||||
initialNexthop = initialNextHopV6
|
initialNexthop = initialNextHopV6
|
||||||
@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
|||||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
|
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||||
return fmt.Errorf("adding route reference: %v", err)
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -254,6 +254,12 @@ message NetworkMap {
|
|||||||
|
|
||||||
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
|
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
|
||||||
bool firewallRulesIsEmpty = 9;
|
bool firewallRulesIsEmpty = 9;
|
||||||
|
|
||||||
|
// RoutesFirewallRules represents a list of routes firewall rules to be applied to peer
|
||||||
|
repeated RouteFirewallRule routesFirewallRules = 10;
|
||||||
|
|
||||||
|
// RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
|
||||||
|
bool routesFirewallRulesIsEmpty = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemotePeerConfig represents a configuration of a remote peer.
|
// RemotePeerConfig represents a configuration of a remote peer.
|
||||||
@ -384,29 +390,32 @@ message NameServer {
|
|||||||
int64 Port = 3;
|
int64 Port = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum RuleProtocol {
|
||||||
|
UNKNOWN = 0;
|
||||||
|
ALL = 1;
|
||||||
|
TCP = 2;
|
||||||
|
UDP = 3;
|
||||||
|
ICMP = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum RuleDirection {
|
||||||
|
IN = 0;
|
||||||
|
OUT = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum RuleAction {
|
||||||
|
ACCEPT = 0;
|
||||||
|
DROP = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// FirewallRule represents a firewall rule
|
// FirewallRule represents a firewall rule
|
||||||
message FirewallRule {
|
message FirewallRule {
|
||||||
string PeerIP = 1;
|
string PeerIP = 1;
|
||||||
direction Direction = 2;
|
RuleDirection Direction = 2;
|
||||||
action Action = 3;
|
RuleAction Action = 3;
|
||||||
protocol Protocol = 4;
|
RuleProtocol Protocol = 4;
|
||||||
string Port = 5;
|
string Port = 5;
|
||||||
|
|
||||||
enum direction {
|
|
||||||
IN = 0;
|
|
||||||
OUT = 1;
|
|
||||||
}
|
|
||||||
enum action {
|
|
||||||
ACCEPT = 0;
|
|
||||||
DROP = 1;
|
|
||||||
}
|
|
||||||
enum protocol {
|
|
||||||
UNKNOWN = 0;
|
|
||||||
ALL = 1;
|
|
||||||
TCP = 2;
|
|
||||||
UDP = 3;
|
|
||||||
ICMP = 4;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message NetworkAddress {
|
message NetworkAddress {
|
||||||
@ -415,5 +424,40 @@ message NetworkAddress {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message Checks {
|
message Checks {
|
||||||
repeated string Files= 1;
|
repeated string Files = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message PortInfo {
|
||||||
|
oneof portSelection {
|
||||||
|
uint32 port = 1;
|
||||||
|
Range range = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Range {
|
||||||
|
uint32 start = 1;
|
||||||
|
uint32 end = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteFirewallRule signifies a firewall rule applicable for a routed network.
|
||||||
|
message RouteFirewallRule {
|
||||||
|
// sourceRanges IP ranges of the routing peers.
|
||||||
|
repeated string sourceRanges = 1;
|
||||||
|
|
||||||
|
// Action to be taken by the firewall when the rule is applicable.
|
||||||
|
RuleAction action = 2;
|
||||||
|
|
||||||
|
// Network prefix for the routed network.
|
||||||
|
string destination = 3;
|
||||||
|
|
||||||
|
// Protocol of the routed network.
|
||||||
|
RuleProtocol protocol = 4;
|
||||||
|
|
||||||
|
// Details about the port.
|
||||||
|
PortInfo portInfo = 5;
|
||||||
|
|
||||||
|
// IsDynamic indicates if the route is a DNS route.
|
||||||
|
bool isDynamic = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ type AccountManager interface {
|
|||||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||||
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
|
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
|
||||||
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||||
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||||
@ -460,6 +460,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
}
|
}
|
||||||
|
|
||||||
routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect)
|
routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect)
|
||||||
|
routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||||
|
|
||||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||||
dnsUpdate := nbdns.Config{
|
dnsUpdate := nbdns.Config{
|
||||||
@ -483,6 +484,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
DNSConfig: dnsUpdate,
|
DNSConfig: dnsUpdate,
|
||||||
OfflinePeers: expiredPeers,
|
OfflinePeers: expiredPeers,
|
||||||
FirewallRules: firewallRules,
|
FirewallRules: firewallRules,
|
||||||
|
RoutesFirewallRules: routesFirewallRules,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metrics != nil {
|
if metrics != nil {
|
||||||
|
@ -1599,9 +1599,10 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Routes: map[route.ID]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
PeerGroups: []string{},
|
PeerGroups: []string{},
|
||||||
Groups: []string{"group1"},
|
Groups: []string{"group1"},
|
||||||
|
AccessControlGroups: []string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||||
|
@ -596,6 +596,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
|
|||||||
response.NetworkMap.FirewallRules = firewallRules
|
response.NetworkMap.FirewallRules = firewallRules
|
||||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||||
|
|
||||||
|
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||||
|
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||||
|
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -727,17 +727,39 @@ components:
|
|||||||
enum: ["all", "tcp", "udp", "icmp"]
|
enum: ["all", "tcp", "udp", "icmp"]
|
||||||
example: "tcp"
|
example: "tcp"
|
||||||
ports:
|
ports:
|
||||||
description: Policy rule affected ports or it ranges list
|
description: Policy rule affected ports
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
example: "80"
|
example: "80"
|
||||||
|
port_ranges:
|
||||||
|
description: Policy rule affected ports ranges list
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/RulePortRange'
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
- enabled
|
- enabled
|
||||||
- bidirectional
|
- bidirectional
|
||||||
- protocol
|
- protocol
|
||||||
- action
|
- action
|
||||||
|
|
||||||
|
RulePortRange:
|
||||||
|
description: Policy rule affected ports range
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
start:
|
||||||
|
description: The starting port of the range
|
||||||
|
type: integer
|
||||||
|
example: 80
|
||||||
|
end:
|
||||||
|
description: The ending port of the range
|
||||||
|
type: integer
|
||||||
|
example: 320
|
||||||
|
required:
|
||||||
|
- start
|
||||||
|
- end
|
||||||
|
|
||||||
PolicyRuleUpdate:
|
PolicyRuleUpdate:
|
||||||
allOf:
|
allOf:
|
||||||
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
||||||
@ -1106,6 +1128,12 @@ components:
|
|||||||
description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore
|
description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
access_control_groups:
|
||||||
|
description: Access control group identifier associated with route.
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: "chacbco6lnnbn6cg5s91"
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- description
|
- description
|
||||||
|
@ -780,7 +780,10 @@ type PolicyRule struct {
|
|||||||
// Name Policy rule name identifier
|
// Name Policy rule name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
// Ports Policy rule affected ports or it ranges list
|
// PortRanges Policy rule affected ports ranges list
|
||||||
|
PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
|
||||||
|
|
||||||
|
// Ports Policy rule affected ports
|
||||||
Ports *[]string `json:"ports,omitempty"`
|
Ports *[]string `json:"ports,omitempty"`
|
||||||
|
|
||||||
// Protocol Policy rule type of the traffic
|
// Protocol Policy rule type of the traffic
|
||||||
@ -816,7 +819,10 @@ type PolicyRuleMinimum struct {
|
|||||||
// Name Policy rule name identifier
|
// Name Policy rule name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
// Ports Policy rule affected ports or it ranges list
|
// PortRanges Policy rule affected ports ranges list
|
||||||
|
PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
|
||||||
|
|
||||||
|
// Ports Policy rule affected ports
|
||||||
Ports *[]string `json:"ports,omitempty"`
|
Ports *[]string `json:"ports,omitempty"`
|
||||||
|
|
||||||
// Protocol Policy rule type of the traffic
|
// Protocol Policy rule type of the traffic
|
||||||
@ -852,7 +858,10 @@ type PolicyRuleUpdate struct {
|
|||||||
// Name Policy rule name identifier
|
// Name Policy rule name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
// Ports Policy rule affected ports or it ranges list
|
// PortRanges Policy rule affected ports ranges list
|
||||||
|
PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
|
||||||
|
|
||||||
|
// Ports Policy rule affected ports
|
||||||
Ports *[]string `json:"ports,omitempty"`
|
Ports *[]string `json:"ports,omitempty"`
|
||||||
|
|
||||||
// Protocol Policy rule type of the traffic
|
// Protocol Policy rule type of the traffic
|
||||||
@ -935,6 +944,9 @@ type ProcessCheck struct {
|
|||||||
|
|
||||||
// Route defines model for Route.
|
// Route defines model for Route.
|
||||||
type Route struct {
|
type Route struct {
|
||||||
|
// AccessControlGroups Access control group identifier associated with route.
|
||||||
|
AccessControlGroups *[]string `json:"access_control_groups,omitempty"`
|
||||||
|
|
||||||
// Description Route description
|
// Description Route description
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
|
|
||||||
@ -977,6 +989,9 @@ type Route struct {
|
|||||||
|
|
||||||
// RouteRequest defines model for RouteRequest.
|
// RouteRequest defines model for RouteRequest.
|
||||||
type RouteRequest struct {
|
type RouteRequest struct {
|
||||||
|
// AccessControlGroups Access control group identifier associated with route.
|
||||||
|
AccessControlGroups *[]string `json:"access_control_groups,omitempty"`
|
||||||
|
|
||||||
// Description Route description
|
// Description Route description
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
|
|
||||||
@ -1011,6 +1026,15 @@ type RouteRequest struct {
|
|||||||
PeerGroups *[]string `json:"peer_groups,omitempty"`
|
PeerGroups *[]string `json:"peer_groups,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RulePortRange Policy rule affected ports range
|
||||||
|
type RulePortRange struct {
|
||||||
|
// End The ending port of the range
|
||||||
|
End int `json:"end"`
|
||||||
|
|
||||||
|
// Start The starting port of the range
|
||||||
|
Start int `json:"start"`
|
||||||
|
}
|
||||||
|
|
||||||
// SetupKey defines model for SetupKey.
|
// SetupKey defines model for SetupKey.
|
||||||
type SetupKey struct {
|
type SetupKey struct {
|
||||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||||
|
@ -172,6 +172,11 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if rule.Ports != nil && len(*rule.Ports) != 0 {
|
if rule.Ports != nil && len(*rule.Ports) != 0 {
|
||||||
for _, v := range *rule.Ports {
|
for _, v := range *rule.Ports {
|
||||||
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
|
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
|
||||||
@ -182,10 +187,23 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.PortRanges != nil && len(*rule.PortRanges) != 0 {
|
||||||
|
for _, portRange := range *rule.PortRanges {
|
||||||
|
if portRange.Start < 1 || portRange.End > 65535 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pr.PortRanges = append(pr.PortRanges, server.RulePortRange{
|
||||||
|
Start: uint16(portRange.Start),
|
||||||
|
End: uint16(portRange.End),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// validate policy object
|
// validate policy object
|
||||||
switch pr.Protocol {
|
switch pr.Protocol {
|
||||||
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
|
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
|
||||||
if len(pr.Ports) != 0 {
|
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -194,7 +212,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
|
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
|
||||||
if !pr.Bidirectional && len(pr.Ports) == 0 {
|
if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -320,6 +338,17 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic
|
|||||||
rule.Ports = &portsCopy
|
rule.Ports = &portsCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(r.PortRanges) != 0 {
|
||||||
|
portRanges := make([]api.RulePortRange, 0, len(r.PortRanges))
|
||||||
|
for _, portRange := range r.PortRanges {
|
||||||
|
portRanges = append(portRanges, api.RulePortRange{
|
||||||
|
End: int(portRange.End),
|
||||||
|
Start: int(portRange.Start),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
rule.PortRanges = &portRanges
|
||||||
|
}
|
||||||
|
|
||||||
for _, gid := range r.Sources {
|
for _, gid := range r.Sources {
|
||||||
_, ok := cache[gid]
|
_, ok := cache[gid]
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -117,9 +117,14 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
peerGroupIds = *req.PeerGroups
|
peerGroupIds = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var accessControlGroupIds []string
|
||||||
|
if req.AccessControlGroups != nil {
|
||||||
|
accessControlGroupIds = *req.AccessControlGroups
|
||||||
|
}
|
||||||
|
|
||||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
|
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute)
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -233,6 +238,10 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
newRoute.PeerGroups = *req.PeerGroups
|
newRoute.PeerGroups = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.AccessControlGroups != nil {
|
||||||
|
newRoute.AccessControlGroups = *req.AccessControlGroups
|
||||||
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
@ -326,6 +335,9 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
|
|||||||
if len(serverRoute.PeerGroups) > 0 {
|
if len(serverRoute.PeerGroups) > 0 {
|
||||||
route.PeerGroups = &serverRoute.PeerGroups
|
route.PeerGroups = &serverRoute.PeerGroups
|
||||||
}
|
}
|
||||||
|
if len(serverRoute.AccessControlGroups) > 0 {
|
||||||
|
route.AccessControlGroups = &serverRoute.AccessControlGroups
|
||||||
|
}
|
||||||
return route, nil
|
return route, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||||
},
|
},
|
||||||
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||||
if peerID == notFoundPeerID {
|
if peerID == notFoundPeerID {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
@ -119,18 +119,19 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &route.Route{
|
return &route.Route{
|
||||||
ID: existingRouteID,
|
ID: existingRouteID,
|
||||||
NetID: netID,
|
NetID: netID,
|
||||||
Peer: peerID,
|
Peer: peerID,
|
||||||
PeerGroups: peerGroups,
|
PeerGroups: peerGroups,
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
Domains: domains,
|
Domains: domains,
|
||||||
NetworkType: networkType,
|
NetworkType: networkType,
|
||||||
Description: description,
|
Description: description,
|
||||||
Masquerade: masquerade,
|
Masquerade: masquerade,
|
||||||
Enabled: enabled,
|
Enabled: enabled,
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
KeepRoute: keepRoute,
|
KeepRoute: keepRoute,
|
||||||
|
AccessControlGroups: accessControlGroups,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
|
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
|
||||||
@ -268,6 +269,27 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
Groups: []string{existingGroupID},
|
Groups: []string{existingGroupID},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "POST OK With Access Control Groups",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/routes",
|
||||||
|
requestBody: bytes.NewBuffer(
|
||||||
|
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: true,
|
||||||
|
expectedRoute: &api.Route{
|
||||||
|
Id: existingRouteID,
|
||||||
|
Description: "Post",
|
||||||
|
NetworkId: "awesomeNet",
|
||||||
|
Network: toPtr("192.168.0.0/16"),
|
||||||
|
Peer: &existingPeerID,
|
||||||
|
NetworkType: route.IPv4NetworkString,
|
||||||
|
Masquerade: false,
|
||||||
|
Enabled: false,
|
||||||
|
Groups: []string{existingGroupID},
|
||||||
|
AccessControlGroups: &[]string{existingGroupID},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "POST Non Linux Peer",
|
name: "POST Non Linux Peer",
|
||||||
requestType: http.MethodPost,
|
requestType: http.MethodPost,
|
||||||
|
@ -58,7 +58,7 @@ type MockAccountManager struct {
|
|||||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||||
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
||||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||||
@ -367,7 +367,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID,
|
|||||||
if am.DeleteRuleFunc != nil {
|
if am.DeleteRuleFunc != nil {
|
||||||
return am.DeleteRuleFunc(ctx, accountID, ruleID, userID)
|
return am.DeleteRuleFunc(ctx, accountID, ruleID, userID)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeletePeerRule is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
|
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
|
||||||
@ -442,9 +442,9 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||||
if am.CreateRouteFunc != nil {
|
if am.CreateRouteFunc != nil {
|
||||||
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
|
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
||||||
}
|
}
|
||||||
|
@ -26,12 +26,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type NetworkMap struct {
|
type NetworkMap struct {
|
||||||
Peers []*nbpeer.Peer
|
Peers []*nbpeer.Peer
|
||||||
Network *Network
|
Network *Network
|
||||||
Routes []*route.Route
|
Routes []*route.Route
|
||||||
DNSConfig nbdns.Config
|
DNSConfig nbdns.Config
|
||||||
OfflinePeers []*nbpeer.Peer
|
OfflinePeers []*nbpeer.Peer
|
||||||
FirewallRules []*FirewallRule
|
FirewallRules []*FirewallRule
|
||||||
|
RoutesFirewallRules []*RouteFirewallRule
|
||||||
}
|
}
|
||||||
|
|
||||||
type Network struct {
|
type Network struct {
|
||||||
|
@ -646,7 +646,6 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
|||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||||
@ -991,9 +990,9 @@ func TestToSyncResponse(t *testing.T) {
|
|||||||
// assert network map Firewall
|
// assert network map Firewall
|
||||||
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
|
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
|
||||||
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
|
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
|
||||||
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
|
assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction)
|
||||||
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
|
assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
|
||||||
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
|
assert.Equal(t, proto.RuleProtocol_TCP, response.NetworkMap.FirewallRules[0].Protocol)
|
||||||
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
|
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
|
||||||
// assert posture checks
|
// assert posture checks
|
||||||
assert.Equal(t, 1, len(response.Checks))
|
assert.Equal(t, 1, len(response.Checks))
|
||||||
|
@ -76,6 +76,12 @@ type PolicyUpdateOperation struct {
|
|||||||
Values []string
|
Values []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RulePortRange represents a range of ports for a firewall rule.
|
||||||
|
type RulePortRange struct {
|
||||||
|
Start uint16
|
||||||
|
End uint16
|
||||||
|
}
|
||||||
|
|
||||||
// PolicyRule is the metadata of the policy
|
// PolicyRule is the metadata of the policy
|
||||||
type PolicyRule struct {
|
type PolicyRule struct {
|
||||||
// ID of the policy rule
|
// ID of the policy rule
|
||||||
@ -110,6 +116,9 @@ type PolicyRule struct {
|
|||||||
|
|
||||||
// Ports or it ranges list
|
// Ports or it ranges list
|
||||||
Ports []string `gorm:"serializer:json"`
|
Ports []string `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
// PortRanges a list of port ranges.
|
||||||
|
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy returns a copy of a policy rule
|
// Copy returns a copy of a policy rule
|
||||||
@ -125,10 +134,12 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
|||||||
Bidirectional: pm.Bidirectional,
|
Bidirectional: pm.Bidirectional,
|
||||||
Protocol: pm.Protocol,
|
Protocol: pm.Protocol,
|
||||||
Ports: make([]string, len(pm.Ports)),
|
Ports: make([]string, len(pm.Ports)),
|
||||||
|
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||||
}
|
}
|
||||||
copy(rule.Destinations, pm.Destinations)
|
copy(rule.Destinations, pm.Destinations)
|
||||||
copy(rule.Sources, pm.Sources)
|
copy(rule.Sources, pm.Sources)
|
||||||
copy(rule.Ports, pm.Ports)
|
copy(rule.Ports, pm.Ports)
|
||||||
|
copy(rule.PortRanges, pm.PortRanges)
|
||||||
return rule
|
return rule
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -445,36 +456,17 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||||
result := make([]*proto.FirewallRule, len(update))
|
result := make([]*proto.FirewallRule, len(rules))
|
||||||
for i := range update {
|
for i := range rules {
|
||||||
direction := proto.FirewallRule_IN
|
rule := rules[i]
|
||||||
if update[i].Direction == firewallRuleDirectionOUT {
|
|
||||||
direction = proto.FirewallRule_OUT
|
|
||||||
}
|
|
||||||
action := proto.FirewallRule_ACCEPT
|
|
||||||
if update[i].Action == string(PolicyTrafficActionDrop) {
|
|
||||||
action = proto.FirewallRule_DROP
|
|
||||||
}
|
|
||||||
|
|
||||||
protocol := proto.FirewallRule_UNKNOWN
|
|
||||||
switch PolicyRuleProtocolType(update[i].Protocol) {
|
|
||||||
case PolicyRuleProtocolALL:
|
|
||||||
protocol = proto.FirewallRule_ALL
|
|
||||||
case PolicyRuleProtocolTCP:
|
|
||||||
protocol = proto.FirewallRule_TCP
|
|
||||||
case PolicyRuleProtocolUDP:
|
|
||||||
protocol = proto.FirewallRule_UDP
|
|
||||||
case PolicyRuleProtocolICMP:
|
|
||||||
protocol = proto.FirewallRule_ICMP
|
|
||||||
}
|
|
||||||
|
|
||||||
result[i] = &proto.FirewallRule{
|
result[i] = &proto.FirewallRule{
|
||||||
PeerIP: update[i].PeerIP,
|
PeerIP: rule.PeerIP,
|
||||||
Direction: direction,
|
Direction: getProtoDirection(rule.Direction),
|
||||||
Action: action,
|
Action: getProtoAction(rule.Action),
|
||||||
Protocol: protocol,
|
Protocol: getProtoProtocol(rule.Protocol),
|
||||||
Port: update[i].Port,
|
Port: rule.Port,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
@ -4,9 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
@ -15,6 +21,30 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RouteFirewallRule a firewall rule applicable for a routed network.
|
||||||
|
type RouteFirewallRule struct {
|
||||||
|
// SourceRanges IP ranges of the routing peers.
|
||||||
|
SourceRanges []string
|
||||||
|
|
||||||
|
// Action of the traffic when the rule is applicable
|
||||||
|
Action string
|
||||||
|
|
||||||
|
// Destination a network prefix for the routed traffic
|
||||||
|
Destination string
|
||||||
|
|
||||||
|
// Protocol of the traffic
|
||||||
|
Protocol string
|
||||||
|
|
||||||
|
// Port of the traffic
|
||||||
|
Port uint16
|
||||||
|
|
||||||
|
// PortRange represents the range of ports for a firewall rule
|
||||||
|
PortRange RulePortRange
|
||||||
|
|
||||||
|
// isDynamic indicates whether the rule is for DNS routing
|
||||||
|
IsDynamic bool
|
||||||
|
}
|
||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
@ -112,7 +142,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateRoute creates and saves a new route
|
// CreateRoute creates and saves a new route
|
||||||
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -157,6 +187,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(accessControlGroupIDs) > 0 {
|
||||||
|
err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -187,6 +224,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
newRoute.Enabled = enabled
|
newRoute.Enabled = enabled
|
||||||
newRoute.Groups = groups
|
newRoute.Groups = groups
|
||||||
newRoute.KeepRoute = keepRoute
|
newRoute.KeepRoute = keepRoute
|
||||||
|
newRoute.AccessControlGroups = accessControlGroupIDs
|
||||||
|
|
||||||
if account.Routes == nil {
|
if account.Routes == nil {
|
||||||
account.Routes = make(map[route.ID]*route.Route)
|
account.Routes = make(map[route.ID]*route.Route)
|
||||||
@ -258,6 +296,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(routeToSave.AccessControlGroups) > 0 {
|
||||||
|
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -351,3 +396,248 @@ func getPlaceholderIP() netip.Prefix {
|
|||||||
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
|
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
|
||||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||||
|
func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||||
|
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||||
|
|
||||||
|
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||||
|
for _, route := range enabledRoutes {
|
||||||
|
// If no access control groups are specified, accept all traffic.
|
||||||
|
if len(route.AccessControlGroups) == 0 {
|
||||||
|
defaultPermit := getDefaultPermit(route)
|
||||||
|
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups)
|
||||||
|
for _, policy := range policies {
|
||||||
|
if !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if !rule.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap)
|
||||||
|
rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN)
|
||||||
|
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routesFirewallRules
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||||
|
var rules []*RouteFirewallRule
|
||||||
|
|
||||||
|
sources := []string{"0.0.0.0/0"}
|
||||||
|
if route.Network.Addr().Is6() {
|
||||||
|
sources = []string{"::/0"}
|
||||||
|
}
|
||||||
|
rule := RouteFirewallRule{
|
||||||
|
SourceRanges: sources,
|
||||||
|
Action: string(PolicyTrafficActionAccept),
|
||||||
|
Destination: route.Network.String(),
|
||||||
|
Protocol: string(PolicyRuleProtocolALL),
|
||||||
|
IsDynamic: route.IsDynamic(),
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = append(rules, &rule)
|
||||||
|
|
||||||
|
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||||
|
if route.IsDynamic() {
|
||||||
|
ruleV6 := rule
|
||||||
|
ruleV6.SourceRanges = []string{"::/0"}
|
||||||
|
rules = append(rules, &ruleV6)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||||
|
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||||
|
func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||||
|
routePolicies := make([]*Policy, 0)
|
||||||
|
for _, groupID := range accessControlGroups {
|
||||||
|
group, ok := account.Groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, policy := range account.Policies {
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
|
||||||
|
return groupID == group.ID
|
||||||
|
})
|
||||||
|
if exist {
|
||||||
|
routePolicies = append(routePolicies, policy)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routePolicies
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRouteFirewallRules generates a list of firewall rules for a given route.
|
||||||
|
func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
|
||||||
|
rulesExists := make(map[string]struct{})
|
||||||
|
rules := make([]*RouteFirewallRule, 0)
|
||||||
|
|
||||||
|
sourceRanges := make([]string, 0, len(groupPeers))
|
||||||
|
for _, peer := range groupPeers {
|
||||||
|
if peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
|
||||||
|
}
|
||||||
|
|
||||||
|
baseRule := RouteFirewallRule{
|
||||||
|
SourceRanges: sourceRanges,
|
||||||
|
Action: string(rule.Action),
|
||||||
|
Destination: route.Network.String(),
|
||||||
|
Protocol: string(rule.Protocol),
|
||||||
|
IsDynamic: route.IsDynamic(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate rule for port range
|
||||||
|
if len(rule.Ports) == 0 {
|
||||||
|
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
|
||||||
|
} else {
|
||||||
|
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: generate IPv6 rules for dynamic routes
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRuleIDBase generates the base rule ID for checking duplicates.
|
||||||
|
func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
|
||||||
|
return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRulesForPeer generates rules for a given peer based on ports and port ranges.
|
||||||
|
func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
|
||||||
|
rules := make([]*RouteFirewallRule, 0)
|
||||||
|
|
||||||
|
ruleIDBase := generateRuleIDBase(rule, baseRule)
|
||||||
|
if len(rule.Ports) == 0 {
|
||||||
|
if len(rule.PortRanges) == 0 {
|
||||||
|
if _, ok := rulesExists[ruleIDBase]; !ok {
|
||||||
|
rulesExists[ruleIDBase] = struct{}{}
|
||||||
|
rules = append(rules, &baseRule)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, portRange := range rule.PortRanges {
|
||||||
|
ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
|
||||||
|
if _, ok := rulesExists[ruleID]; !ok {
|
||||||
|
rulesExists[ruleID] = struct{}{}
|
||||||
|
pr := baseRule
|
||||||
|
pr.PortRange = portRange
|
||||||
|
rules = append(rules, &pr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRulesWithPorts generates rules when specific ports are provided.
|
||||||
|
func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
|
||||||
|
rules := make([]*RouteFirewallRule, 0)
|
||||||
|
ruleIDBase := generateRuleIDBase(rule, baseRule)
|
||||||
|
|
||||||
|
for _, port := range rule.Ports {
|
||||||
|
ruleID := ruleIDBase + port
|
||||||
|
if _, ok := rulesExists[ruleID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rulesExists[ruleID] = struct{}{}
|
||||||
|
|
||||||
|
pr := baseRule
|
||||||
|
p, err := strconv.ParseUint(port, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pr.Port = uint16(p)
|
||||||
|
rules = append(rules, &pr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||||
|
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||||
|
for i := range rules {
|
||||||
|
rule := rules[i]
|
||||||
|
result[i] = &proto.RouteFirewallRule{
|
||||||
|
SourceRanges: rule.SourceRanges,
|
||||||
|
Action: getProtoAction(rule.Action),
|
||||||
|
Destination: rule.Destination,
|
||||||
|
Protocol: getProtoProtocol(rule.Protocol),
|
||||||
|
PortInfo: getProtoPortInfo(rule),
|
||||||
|
IsDynamic: rule.IsDynamic,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||||
|
func getProtoDirection(direction int) proto.RuleDirection {
|
||||||
|
if direction == firewallRuleDirectionOUT {
|
||||||
|
return proto.RuleDirection_OUT
|
||||||
|
}
|
||||||
|
return proto.RuleDirection_IN
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProtoAction converts the action to proto.RuleAction.
|
||||||
|
func getProtoAction(action string) proto.RuleAction {
|
||||||
|
if action == string(PolicyTrafficActionDrop) {
|
||||||
|
return proto.RuleAction_DROP
|
||||||
|
}
|
||||||
|
return proto.RuleAction_ACCEPT
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||||
|
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||||
|
switch PolicyRuleProtocolType(protocol) {
|
||||||
|
case PolicyRuleProtocolALL:
|
||||||
|
return proto.RuleProtocol_ALL
|
||||||
|
case PolicyRuleProtocolTCP:
|
||||||
|
return proto.RuleProtocol_TCP
|
||||||
|
case PolicyRuleProtocolUDP:
|
||||||
|
return proto.RuleProtocol_UDP
|
||||||
|
case PolicyRuleProtocolICMP:
|
||||||
|
return proto.RuleProtocol_ICMP
|
||||||
|
default:
|
||||||
|
return proto.RuleProtocol_UNKNOWN
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||||
|
func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||||
|
var portInfo proto.PortInfo
|
||||||
|
if rule.Port != 0 {
|
||||||
|
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||||
|
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||||
|
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||||
|
Range: &proto.PortInfo_Range{
|
||||||
|
Start: uint32(portRange.Start),
|
||||||
|
End: uint32(portRange.End),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &portInfo
|
||||||
|
}
|
||||||
|
@ -2,6 +2,8 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -44,18 +46,19 @@ var existingDomains = domain.List{"example.com"}
|
|||||||
|
|
||||||
func TestCreateRoute(t *testing.T) {
|
func TestCreateRoute(t *testing.T) {
|
||||||
type input struct {
|
type input struct {
|
||||||
network netip.Prefix
|
network netip.Prefix
|
||||||
domains domain.List
|
domains domain.List
|
||||||
keepRoute bool
|
keepRoute bool
|
||||||
networkType route.NetworkType
|
networkType route.NetworkType
|
||||||
netID route.NetID
|
netID route.NetID
|
||||||
peerKey string
|
peerKey string
|
||||||
peerGroupIDs []string
|
peerGroupIDs []string
|
||||||
description string
|
description string
|
||||||
masquerade bool
|
masquerade bool
|
||||||
metric int
|
metric int
|
||||||
enabled bool
|
enabled bool
|
||||||
groups []string
|
groups []string
|
||||||
|
accessControlGroups []string
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@ -69,100 +72,107 @@ func TestCreateRoute(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Happy Path Network",
|
name: "Happy Path Network",
|
||||||
inputArgs: input{
|
inputArgs: input{
|
||||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
networkType: route.IPv4Network,
|
networkType: route.IPv4Network,
|
||||||
netID: "happy",
|
netID: "happy",
|
||||||
peerKey: peer1ID,
|
peerKey: peer1ID,
|
||||||
description: "super",
|
description: "super",
|
||||||
masquerade: false,
|
masquerade: false,
|
||||||
metric: 9999,
|
metric: 9999,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
groups: []string{routeGroup1},
|
groups: []string{routeGroup1},
|
||||||
|
accessControlGroups: []string{routeGroup1},
|
||||||
},
|
},
|
||||||
errFunc: require.NoError,
|
errFunc: require.NoError,
|
||||||
shouldCreate: true,
|
shouldCreate: true,
|
||||||
expectedRoute: &route.Route{
|
expectedRoute: &route.Route{
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
NetID: "happy",
|
NetID: "happy",
|
||||||
Peer: peer1ID,
|
Peer: peer1ID,
|
||||||
Description: "super",
|
Description: "super",
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Groups: []string{routeGroup1},
|
Groups: []string{routeGroup1},
|
||||||
|
AccessControlGroups: []string{routeGroup1},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Happy Path Domains",
|
name: "Happy Path Domains",
|
||||||
inputArgs: input{
|
inputArgs: input{
|
||||||
domains: domain.List{"domain1", "domain2"},
|
domains: domain.List{"domain1", "domain2"},
|
||||||
keepRoute: true,
|
keepRoute: true,
|
||||||
networkType: route.DomainNetwork,
|
networkType: route.DomainNetwork,
|
||||||
netID: "happy",
|
netID: "happy",
|
||||||
peerKey: peer1ID,
|
peerKey: peer1ID,
|
||||||
description: "super",
|
description: "super",
|
||||||
masquerade: false,
|
masquerade: false,
|
||||||
metric: 9999,
|
metric: 9999,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
groups: []string{routeGroup1},
|
groups: []string{routeGroup1},
|
||||||
|
accessControlGroups: []string{routeGroup1},
|
||||||
},
|
},
|
||||||
errFunc: require.NoError,
|
errFunc: require.NoError,
|
||||||
shouldCreate: true,
|
shouldCreate: true,
|
||||||
expectedRoute: &route.Route{
|
expectedRoute: &route.Route{
|
||||||
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
||||||
Domains: domain.List{"domain1", "domain2"},
|
Domains: domain.List{"domain1", "domain2"},
|
||||||
NetworkType: route.DomainNetwork,
|
NetworkType: route.DomainNetwork,
|
||||||
NetID: "happy",
|
NetID: "happy",
|
||||||
Peer: peer1ID,
|
Peer: peer1ID,
|
||||||
Description: "super",
|
Description: "super",
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Groups: []string{routeGroup1},
|
Groups: []string{routeGroup1},
|
||||||
KeepRoute: true,
|
KeepRoute: true,
|
||||||
|
AccessControlGroups: []string{routeGroup1},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Happy Path Peer Groups",
|
name: "Happy Path Peer Groups",
|
||||||
inputArgs: input{
|
inputArgs: input{
|
||||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
networkType: route.IPv4Network,
|
networkType: route.IPv4Network,
|
||||||
netID: "happy",
|
netID: "happy",
|
||||||
peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
|
peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
|
||||||
description: "super",
|
description: "super",
|
||||||
masquerade: false,
|
masquerade: false,
|
||||||
metric: 9999,
|
metric: 9999,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
groups: []string{routeGroup1, routeGroup2},
|
groups: []string{routeGroup1, routeGroup2},
|
||||||
|
accessControlGroups: []string{routeGroup1, routeGroup2},
|
||||||
},
|
},
|
||||||
errFunc: require.NoError,
|
errFunc: require.NoError,
|
||||||
shouldCreate: true,
|
shouldCreate: true,
|
||||||
expectedRoute: &route.Route{
|
expectedRoute: &route.Route{
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
NetID: "happy",
|
NetID: "happy",
|
||||||
PeerGroups: []string{routeGroupHA1, routeGroupHA2},
|
PeerGroups: []string{routeGroupHA1, routeGroupHA2},
|
||||||
Description: "super",
|
Description: "super",
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Groups: []string{routeGroup1, routeGroup2},
|
Groups: []string{routeGroup1, routeGroup2},
|
||||||
|
AccessControlGroups: []string{routeGroup1, routeGroup2},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Both network and domains provided should fail",
|
name: "Both network and domains provided should fail",
|
||||||
inputArgs: input{
|
inputArgs: input{
|
||||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
domains: domain.List{"domain1", "domain2"},
|
domains: domain.List{"domain1", "domain2"},
|
||||||
netID: "happy",
|
netID: "happy",
|
||||||
peerKey: peer1ID,
|
peerKey: peer1ID,
|
||||||
peerGroupIDs: []string{routeGroupHA1},
|
peerGroupIDs: []string{routeGroupHA1},
|
||||||
description: "super",
|
description: "super",
|
||||||
masquerade: false,
|
masquerade: false,
|
||||||
metric: 9999,
|
metric: 9999,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
groups: []string{routeGroup1},
|
groups: []string{routeGroup1},
|
||||||
|
accessControlGroups: []string{routeGroup2},
|
||||||
},
|
},
|
||||||
errFunc: require.Error,
|
errFunc: require.Error,
|
||||||
shouldCreate: false,
|
shouldCreate: false,
|
||||||
@ -170,16 +180,17 @@ func TestCreateRoute(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Both peer and peer_groups Provided Should Fail",
|
name: "Both peer and peer_groups Provided Should Fail",
|
||||||
inputArgs: input{
|
inputArgs: input{
|
||||||
network: netip.MustParsePrefix("192.168.0.0/16"),
|
network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
networkType: route.IPv4Network,
|
networkType: route.IPv4Network,
|
||||||
netID: "happy",
|
netID: "happy",
|
||||||
peerKey: peer1ID,
|
peerKey: peer1ID,
|
||||||
peerGroupIDs: []string{routeGroupHA1},
|
peerGroupIDs: []string{routeGroupHA1},
|
||||||
description: "super",
|
description: "super",
|
||||||
masquerade: false,
|
masquerade: false,
|
||||||
metric: 9999,
|
metric: 9999,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
groups: []string{routeGroup1},
|
groups: []string{routeGroup1},
|
||||||
|
accessControlGroups: []string{routeGroup2},
|
||||||
},
|
},
|
||||||
errFunc: require.Error,
|
errFunc: require.Error,
|
||||||
shouldCreate: false,
|
shouldCreate: false,
|
||||||
@ -423,13 +434,13 @@ func TestCreateRoute(t *testing.T) {
|
|||||||
if testCase.createInitRoute {
|
if testCase.createInitRoute {
|
||||||
groupAll, errInit := account.GetGroupAll()
|
groupAll, errInit := account.GetGroupAll()
|
||||||
require.NoError(t, errInit)
|
require.NoError(t, errInit)
|
||||||
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
|
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
|
||||||
require.NoError(t, errInit)
|
require.NoError(t, errInit)
|
||||||
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
|
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
|
||||||
require.NoError(t, errInit)
|
require.NoError(t, errInit)
|
||||||
}
|
}
|
||||||
|
|
||||||
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
||||||
|
|
||||||
testCase.errFunc(t, err)
|
testCase.errFunc(t, err)
|
||||||
|
|
||||||
@ -1037,15 +1048,16 @@ func TestDeleteRoute(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||||
baseRoute := &route.Route{
|
baseRoute := &route.Route{
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
NetID: "superNet",
|
NetID: "superNet",
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
PeerGroups: []string{routeGroupHA1, routeGroupHA2},
|
PeerGroups: []string{routeGroupHA1, routeGroupHA2},
|
||||||
Description: "ha route",
|
Description: "ha route",
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Groups: []string{routeGroup1, routeGroup2},
|
Groups: []string{routeGroup1, routeGroup2},
|
||||||
|
AccessControlGroups: []string{routeGroup1},
|
||||||
}
|
}
|
||||||
|
|
||||||
am, err := createRouterManager(t)
|
am, err := createRouterManager(t)
|
||||||
@ -1062,7 +1074,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||||
|
|
||||||
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, newRoute.Enabled, true)
|
require.Equal(t, newRoute.Enabled, true)
|
||||||
|
|
||||||
@ -1127,16 +1139,17 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
|||||||
// no routes for peer in different groups
|
// no routes for peer in different groups
|
||||||
// no routes when route is deleted
|
// no routes when route is deleted
|
||||||
baseRoute := &route.Route{
|
baseRoute := &route.Route{
|
||||||
ID: "testingRoute",
|
ID: "testingRoute",
|
||||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
NetID: "superNet",
|
NetID: "superNet",
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Peer: peer1ID,
|
Peer: peer1ID,
|
||||||
Description: "super",
|
Description: "super",
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Groups: []string{routeGroup1},
|
Groups: []string{routeGroup1},
|
||||||
|
AccessControlGroups: []string{routeGroup1},
|
||||||
}
|
}
|
||||||
|
|
||||||
am, err := createRouterManager(t)
|
am, err := createRouterManager(t)
|
||||||
@ -1153,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||||
|
|
||||||
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute)
|
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||||
@ -1467,3 +1480,300 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
|||||||
|
|
||||||
return am.Store.GetAccount(context.Background(), account.Id)
|
return am.Store.GetAccount(context.Background(), account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||||
|
var (
|
||||||
|
peerBIp = "100.65.80.39"
|
||||||
|
peerCIp = "100.65.254.139"
|
||||||
|
peerHIp = "100.65.29.55"
|
||||||
|
)
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
Peers: map[string]*nbpeer.Peer{
|
||||||
|
"peerA": {
|
||||||
|
ID: "peerA",
|
||||||
|
IP: net.ParseIP("100.65.14.88"),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"peerB": {
|
||||||
|
ID: "peerB",
|
||||||
|
IP: net.ParseIP(peerBIp),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{},
|
||||||
|
},
|
||||||
|
"peerC": {
|
||||||
|
ID: "peerC",
|
||||||
|
IP: net.ParseIP(peerCIp),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
},
|
||||||
|
"peerD": {
|
||||||
|
ID: "peerD",
|
||||||
|
IP: net.ParseIP("100.65.62.5"),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"peerE": {
|
||||||
|
ID: "peerE",
|
||||||
|
IP: net.ParseIP("100.65.32.206"),
|
||||||
|
Key: peer1Key,
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"peerF": {
|
||||||
|
ID: "peerF",
|
||||||
|
IP: net.ParseIP("100.65.250.202"),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
},
|
||||||
|
"peerG": {
|
||||||
|
ID: "peerG",
|
||||||
|
IP: net.ParseIP("100.65.13.186"),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
},
|
||||||
|
"peerH": {
|
||||||
|
ID: "peerH",
|
||||||
|
IP: net.ParseIP(peerHIp),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Groups: map[string]*nbgroup.Group{
|
||||||
|
"routingPeer1": {
|
||||||
|
ID: "routingPeer1",
|
||||||
|
Name: "RoutingPeer1",
|
||||||
|
Peers: []string{
|
||||||
|
"peerA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"routingPeer2": {
|
||||||
|
ID: "routingPeer2",
|
||||||
|
Name: "RoutingPeer2",
|
||||||
|
Peers: []string{
|
||||||
|
"peerD",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Name: "Route1",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Name: "Route2",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
"finance": {
|
||||||
|
ID: "finance",
|
||||||
|
Name: "Finance",
|
||||||
|
Peers: []string{
|
||||||
|
"peerF",
|
||||||
|
"peerG",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"dev": {
|
||||||
|
ID: "dev",
|
||||||
|
Name: "Dev",
|
||||||
|
Peers: []string{
|
||||||
|
"peerC",
|
||||||
|
"peerH",
|
||||||
|
"peerB",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"contractors": {
|
||||||
|
ID: "contractors",
|
||||||
|
Name: "Contractors",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Routes: map[route.ID]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
NetID: "route1",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
PeerGroups: []string{"routingPeer1", "routingPeer2"},
|
||||||
|
Description: "Route1 ha route",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"dev"},
|
||||||
|
AccessControlGroups: []string{"route1"},
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Network: existingNetwork,
|
||||||
|
NetID: "route2",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
Peer: "peerE",
|
||||||
|
Description: "Allow",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"finance"},
|
||||||
|
AccessControlGroups: []string{"route2"},
|
||||||
|
},
|
||||||
|
"route3": {
|
||||||
|
ID: "route3",
|
||||||
|
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
||||||
|
Domains: domain.List{"example.com"},
|
||||||
|
NetID: "route3",
|
||||||
|
NetworkType: route.DomainNetwork,
|
||||||
|
Peer: "peerE",
|
||||||
|
Description: "Allow all traffic to routed DNS network",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"contractors"},
|
||||||
|
AccessControlGroups: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Policies: []*Policy{
|
||||||
|
{
|
||||||
|
ID: "RuleRoute1",
|
||||||
|
Name: "Route1",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "RuleRoute1",
|
||||||
|
Name: "ruleRoute1",
|
||||||
|
Bidirectional: true,
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: PolicyRuleProtocolALL,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Ports: []string{"80", "320"},
|
||||||
|
Sources: []string{
|
||||||
|
"dev",
|
||||||
|
},
|
||||||
|
Destinations: []string{
|
||||||
|
"route1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "RuleRoute2",
|
||||||
|
Name: "Route2",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "RuleRoute2",
|
||||||
|
Name: "ruleRoute2",
|
||||||
|
Bidirectional: true,
|
||||||
|
Enabled: true,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{
|
||||||
|
Start: 80,
|
||||||
|
End: 350,
|
||||||
|
}, {
|
||||||
|
Start: 80,
|
||||||
|
End: 350,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Sources: []string{
|
||||||
|
"finance",
|
||||||
|
},
|
||||||
|
Destinations: []string{
|
||||||
|
"route2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
validatedPeers := make(map[string]struct{})
|
||||||
|
for p := range account.Peers {
|
||||||
|
validatedPeers[p] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("check applied policies for the route", func(t *testing.T) {
|
||||||
|
route1 := account.Routes["route1"]
|
||||||
|
policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
|
||||||
|
assert.Len(t, policies, 1)
|
||||||
|
|
||||||
|
route2 := account.Routes["route2"]
|
||||||
|
policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups)
|
||||||
|
assert.Len(t, policies, 1)
|
||||||
|
|
||||||
|
route3 := account.Routes["route3"]
|
||||||
|
policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
|
||||||
|
assert.Len(t, policies, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("check peer routes firewall rules", func(t *testing.T) {
|
||||||
|
routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
|
||||||
|
assert.Len(t, routesFirewallRules, 2)
|
||||||
|
|
||||||
|
expectedRoutesFirewallRules := []*RouteFirewallRule{
|
||||||
|
{
|
||||||
|
SourceRanges: []string{
|
||||||
|
fmt.Sprintf(AllowedIPsFormat, peerCIp),
|
||||||
|
fmt.Sprintf(AllowedIPsFormat, peerHIp),
|
||||||
|
fmt.Sprintf(AllowedIPsFormat, peerBIp),
|
||||||
|
},
|
||||||
|
Action: "accept",
|
||||||
|
Destination: "192.168.0.0/16",
|
||||||
|
Protocol: "all",
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SourceRanges: []string{
|
||||||
|
fmt.Sprintf(AllowedIPsFormat, peerCIp),
|
||||||
|
fmt.Sprintf(AllowedIPsFormat, peerHIp),
|
||||||
|
fmt.Sprintf(AllowedIPsFormat, peerBIp),
|
||||||
|
},
|
||||||
|
Action: "accept",
|
||||||
|
Destination: "192.168.0.0/16",
|
||||||
|
Protocol: "all",
|
||||||
|
Port: 320,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
|
||||||
|
|
||||||
|
//peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
|
||||||
|
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
|
||||||
|
assert.Len(t, routesFirewallRules, 2)
|
||||||
|
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
|
||||||
|
|
||||||
|
// peerE is a single routing peer for route 2 and route 3
|
||||||
|
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
|
||||||
|
assert.Len(t, routesFirewallRules, 3)
|
||||||
|
|
||||||
|
expectedRoutesFirewallRules = []*RouteFirewallRule{
|
||||||
|
{
|
||||||
|
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
||||||
|
Action: "accept",
|
||||||
|
Destination: existingNetwork.String(),
|
||||||
|
Protocol: "tcp",
|
||||||
|
PortRange: RulePortRange{Start: 80, End: 350},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SourceRanges: []string{"0.0.0.0/0"},
|
||||||
|
Action: "accept",
|
||||||
|
Destination: "192.0.2.0/32",
|
||||||
|
Protocol: "all",
|
||||||
|
IsDynamic: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SourceRanges: []string{"::/0"},
|
||||||
|
Action: "accept",
|
||||||
|
Destination: "192.0.2.0/32",
|
||||||
|
Protocol: "all",
|
||||||
|
IsDynamic: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
|
||||||
|
|
||||||
|
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
|
||||||
|
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||||
|
assert.Len(t, routesFirewallRules, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -100,6 +100,7 @@ type Route struct {
|
|||||||
Metric int
|
Metric int
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Groups []string `gorm:"serializer:json"`
|
Groups []string `gorm:"serializer:json"`
|
||||||
|
AccessControlGroups []string `gorm:"serializer:json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventMeta returns activity event meta related to the route
|
// EventMeta returns activity event meta related to the route
|
||||||
@ -123,6 +124,7 @@ func (r *Route) Copy() *Route {
|
|||||||
Masquerade: r.Masquerade,
|
Masquerade: r.Masquerade,
|
||||||
Enabled: r.Enabled,
|
Enabled: r.Enabled,
|
||||||
Groups: slices.Clone(r.Groups),
|
Groups: slices.Clone(r.Groups),
|
||||||
|
AccessControlGroups: slices.Clone(r.AccessControlGroups),
|
||||||
}
|
}
|
||||||
return route
|
return route
|
||||||
}
|
}
|
||||||
@ -147,7 +149,8 @@ func (r *Route) IsEqual(other *Route) bool {
|
|||||||
other.Masquerade == r.Masquerade &&
|
other.Masquerade == r.Masquerade &&
|
||||||
other.Enabled == r.Enabled &&
|
other.Enabled == r.Enabled &&
|
||||||
slices.Equal(r.Groups, other.Groups) &&
|
slices.Equal(r.Groups, other.Groups) &&
|
||||||
slices.Equal(r.PeerGroups, other.PeerGroups)
|
slices.Equal(r.PeerGroups, other.PeerGroups)&&
|
||||||
|
slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDynamic returns if the route is dynamic, i.e. has domains
|
// IsDynamic returns if the route is dynamic, i.e. has domains
|
||||||
|
Loading…
Reference in New Issue
Block a user