mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-14 14:38:27 +02:00
Fix/acl for forward (#1305)
Fix ACL on routed traffic and code refactor
This commit is contained in:
parent
b03343bc4d
commit
006ba32086
32
client/firewall/create.go
Normal file
32
client/firewall/create.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewFirewall creates a firewall manager instance
|
||||||
|
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||||
|
if !iface.IsUserspaceBind() {
|
||||||
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
// use userspace packet filtering firewall
|
||||||
|
fm, err := uspfilter.Create(iface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = fm.AllowNetbird()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
return fm, nil
|
||||||
|
}
|
107
client/firewall/create_linux.go
Normal file
107
client/firewall/create_linux.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||||
|
UNKNOWN FWType = iota
|
||||||
|
// IPTABLES is the value for the iptables firewall type
|
||||||
|
IPTABLES
|
||||||
|
// NFTABLES is the value for the nftables firewall type
|
||||||
|
NFTABLES
|
||||||
|
)
|
||||||
|
|
||||||
|
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||||
|
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||||
|
|
||||||
|
// FWType is the type for the firewall type
|
||||||
|
type FWType int
|
||||||
|
|
||||||
|
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
||||||
|
// on the linux system we try to user nftables or iptables
|
||||||
|
// in any case, because we need to allow netbird interface traffic
|
||||||
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
|
// for the userspace packet filtering firewall
|
||||||
|
var fm firewall.Manager
|
||||||
|
var errFw error
|
||||||
|
|
||||||
|
switch check() {
|
||||||
|
case IPTABLES:
|
||||||
|
log.Debug("creating an iptables firewall manager")
|
||||||
|
fm, errFw = nbiptables.Create(context, iface)
|
||||||
|
if errFw != nil {
|
||||||
|
log.Errorf("failed to create iptables manager: %s", errFw)
|
||||||
|
}
|
||||||
|
case NFTABLES:
|
||||||
|
log.Debug("creating an nftables firewall manager")
|
||||||
|
fm, errFw = nbnftables.Create(context, iface)
|
||||||
|
if errFw != nil {
|
||||||
|
log.Errorf("failed to create nftables manager: %s", errFw)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
errFw = fmt.Errorf("no firewall manager found")
|
||||||
|
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface.IsUserspaceBind() {
|
||||||
|
var errUsp error
|
||||||
|
if errFw == nil {
|
||||||
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||||
|
} else {
|
||||||
|
fm, errUsp = uspfilter.Create(iface)
|
||||||
|
}
|
||||||
|
if errUsp != nil {
|
||||||
|
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
||||||
|
return nil, errUsp
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if errFw != nil {
|
||||||
|
return nil, errFw
|
||||||
|
}
|
||||||
|
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||||
|
func check() FWType {
|
||||||
|
nf := nftables.Conn{}
|
||||||
|
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||||
|
return NFTABLES
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return UNKNOWN
|
||||||
|
}
|
||||||
|
if isIptablesClientAvailable(ip) {
|
||||||
|
return IPTABLES
|
||||||
|
}
|
||||||
|
|
||||||
|
return UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
11
client/firewall/iface.go
Normal file
11
client/firewall/iface.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/iface"
|
||||||
|
|
||||||
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
|
type IFaceMapper interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
SetFilter(iface.PacketFilter) error
|
||||||
|
}
|
473
client/firewall/iptables/acl_linux.go
Normal file
473
client/firewall/iptables/acl_linux.go
Normal file
@ -0,0 +1,473 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/nadoo/ipset"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tableName = "filter"
|
||||||
|
|
||||||
|
// rules chains contains the effective ACL rules
|
||||||
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
|
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||||
|
|
||||||
|
postRoutingMark = "0x000007e4"
|
||||||
|
)
|
||||||
|
|
||||||
|
type aclManager struct {
|
||||||
|
iptablesClient *iptables.IPTables
|
||||||
|
wgIface iFaceMapper
|
||||||
|
routeingFwChainName string
|
||||||
|
|
||||||
|
entries map[string][][]string
|
||||||
|
ipsetStore *ipsetStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
|
||||||
|
m := &aclManager{
|
||||||
|
iptablesClient: iptablesClient,
|
||||||
|
wgIface: wgIface,
|
||||||
|
routeingFwChainName: routeingFwChainName,
|
||||||
|
|
||||||
|
entries: make(map[string][][]string),
|
||||||
|
ipsetStore: newIpsetStore(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ipset.Init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.seedInitialEntries()
|
||||||
|
|
||||||
|
err = m.cleanChains()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.createDefaultChains()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) AddFiltering(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
direction firewall.RuleDirection,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
|
var dPortVal, sPortVal string
|
||||||
|
if dPort != nil && dPort.Values != nil {
|
||||||
|
// TODO: we support only one port per rule in current implementation of ACLs
|
||||||
|
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||||
|
}
|
||||||
|
if sPort != nil && sPort.Values != nil {
|
||||||
|
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
var chain string
|
||||||
|
if direction == firewall.RuleDirectionOUT {
|
||||||
|
chain = chainNameOutputRules
|
||||||
|
} else {
|
||||||
|
chain = chainNameInputRules
|
||||||
|
}
|
||||||
|
|
||||||
|
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
||||||
|
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
||||||
|
if ipsetName != "" {
|
||||||
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||||
|
}
|
||||||
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
|
ipList.addIP(ip.String())
|
||||||
|
return []firewall.Rule{&Rule{
|
||||||
|
ruleID: uuid.New().String(),
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
chain: chain,
|
||||||
|
specs: specs,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
|
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||||
|
}
|
||||||
|
if err := ipset.Create(ipsetName); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||||
|
}
|
||||||
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipList := newIpList(ip.String())
|
||||||
|
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return nil, fmt.Errorf("rule already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := &Rule{
|
||||||
|
ruleID: uuid.New().String(),
|
||||||
|
specs: specs,
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
chain: chain,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldAddToPrerouting(protocol, dPort, direction) {
|
||||||
|
return []firewall.Rule{rule}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
|
||||||
|
if err != nil {
|
||||||
|
return []firewall.Rule{rule}, err
|
||||||
|
}
|
||||||
|
return []firewall.Rule{rule, rulePrerouting}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRule from the firewall by rule definition
|
||||||
|
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid rule type")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.chain == "PREROUTING" {
|
||||||
|
goto DELETERULE
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||||
|
// delete IP from ruleset IPs list and ipset
|
||||||
|
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||||
|
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||||
|
}
|
||||||
|
delete(ipsetList.ips, r.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if after delete, set still contains other IPs,
|
||||||
|
// no need to delete firewall rule and we should exit here
|
||||||
|
if len(ipsetList.ips) != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// we delete last IP from the set, that means we need to delete
|
||||||
|
// set itself and associated firewall rule too
|
||||||
|
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||||
|
|
||||||
|
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||||
|
log.Errorf("delete empty ipset: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DELETERULE:
|
||||||
|
var table string
|
||||||
|
if r.chain == "PREROUTING" {
|
||||||
|
table = "mangle"
|
||||||
|
} else {
|
||||||
|
table = "filter"
|
||||||
|
}
|
||||||
|
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) Reset() error {
|
||||||
|
return m.cleanChains()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
|
||||||
|
var src []string
|
||||||
|
if ipsetName != "" {
|
||||||
|
src = []string{"-m", "set", "--set", ipsetName, "src"}
|
||||||
|
} else {
|
||||||
|
src = []string{"-s", ip.String()}
|
||||||
|
}
|
||||||
|
specs := []string{
|
||||||
|
"-d", m.wgIface.Address().IP.String(),
|
||||||
|
"-p", protocol,
|
||||||
|
"--dport", port,
|
||||||
|
"-j", "MARK", "--set-mark", postRoutingMark,
|
||||||
|
}
|
||||||
|
|
||||||
|
specs = append(src, specs...)
|
||||||
|
|
||||||
|
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return nil, fmt.Errorf("rule already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := &Rule{
|
||||||
|
ruleID: uuid.New().String(),
|
||||||
|
specs: specs,
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
chain: "PREROUTING",
|
||||||
|
}
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo write less destructive cleanup mechanism
|
||||||
|
func (m *aclManager) cleanChains() error {
|
||||||
|
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to list chains: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
rules := m.entries["OUTPUT"]
|
||||||
|
for _, rule := range rules {
|
||||||
|
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to list chains: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
for _, rule := range m.entries["INPUT"] {
|
||||||
|
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range m.entries["FORWARD"] {
|
||||||
|
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to list chains: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
for _, rule := range m.entries["PREROUTING"] {
|
||||||
|
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
|
if err := ipset.Flush(ipsetName); err != nil {
|
||||||
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
if err := ipset.Destroy(ipsetName); err != nil {
|
||||||
|
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
m.ipsetStore.deleteIpset(ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) createDefaultChains() error {
|
||||||
|
// chain netbird-acl-input-rules
|
||||||
|
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||||
|
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// chain netbird-acl-output-rules
|
||||||
|
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
||||||
|
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for chainName, rules := range m.entries {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if chainName == "FORWARD" {
|
||||||
|
// position 2 because we add it after router's, jump rule
|
||||||
|
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
|
||||||
|
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
|
||||||
|
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) seedInitialEntries() {
|
||||||
|
m.appendToEntries("INPUT",
|
||||||
|
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||||
|
|
||||||
|
m.appendToEntries("INPUT",
|
||||||
|
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||||
|
|
||||||
|
m.appendToEntries("INPUT",
|
||||||
|
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
|
||||||
|
|
||||||
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
|
|
||||||
|
m.appendToEntries("OUTPUT",
|
||||||
|
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||||
|
|
||||||
|
m.appendToEntries("OUTPUT",
|
||||||
|
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
||||||
|
|
||||||
|
m.appendToEntries("OUTPUT",
|
||||||
|
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
|
||||||
|
|
||||||
|
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
||||||
|
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||||
|
m.appendToEntries("FORWARD",
|
||||||
|
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||||
|
m.appendToEntries("FORWARD",
|
||||||
|
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
|
||||||
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
|
||||||
|
|
||||||
|
m.appendToEntries("PREROUTING",
|
||||||
|
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||||
|
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
|
func filterRuleSpecs(
|
||||||
|
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||||
|
) (specs []string) {
|
||||||
|
matchByIP := true
|
||||||
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
|
if ip.String() == "0.0.0.0" {
|
||||||
|
matchByIP = false
|
||||||
|
}
|
||||||
|
switch direction {
|
||||||
|
case firewall.RuleDirectionIN:
|
||||||
|
if matchByIP {
|
||||||
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
|
} else {
|
||||||
|
specs = append(specs, "-s", ip.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case firewall.RuleDirectionOUT:
|
||||||
|
if matchByIP {
|
||||||
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
||||||
|
} else {
|
||||||
|
specs = append(specs, "-d", ip.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if protocol != "all" {
|
||||||
|
specs = append(specs, "-p", protocol)
|
||||||
|
}
|
||||||
|
if sPort != "" {
|
||||||
|
specs = append(specs, "--sport", sPort)
|
||||||
|
}
|
||||||
|
if dPort != "" {
|
||||||
|
specs = append(specs, "--dport", dPort)
|
||||||
|
}
|
||||||
|
return append(specs, "-j", actionToStr(action))
|
||||||
|
}
|
||||||
|
|
||||||
|
func actionToStr(action firewall.Action) string {
|
||||||
|
if action == firewall.ActionAccept {
|
||||||
|
return "ACCEPT"
|
||||||
|
}
|
||||||
|
return "DROP"
|
||||||
|
}
|
||||||
|
|
||||||
|
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
||||||
|
switch {
|
||||||
|
case ipsetName == "":
|
||||||
|
return ""
|
||||||
|
case sPort != "" && dPort != "":
|
||||||
|
return ipsetName + "-sport-dport"
|
||||||
|
case sPort != "":
|
||||||
|
return ipsetName + "-sport"
|
||||||
|
case dPort != "":
|
||||||
|
return ipsetName + "-dport"
|
||||||
|
default:
|
||||||
|
return ipsetName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
|
||||||
|
if proto == "all" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if direction != firewall.RuleDirectionIN {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if dPort == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
@ -1,43 +1,27 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/nadoo/ipset"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
|
|
||||||
ChainInputFilterName = "NETBIRD-ACL-INPUT"
|
|
||||||
|
|
||||||
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
|
|
||||||
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
|
|
||||||
)
|
|
||||||
|
|
||||||
// dropAllDefaultRule in the Netbird chain
|
|
||||||
var dropAllDefaultRule = []string{"-j", "DROP"}
|
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
|
wgIface iFaceMapper
|
||||||
|
|
||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
ipv6Client *iptables.IPTables
|
aclMgr *aclManager
|
||||||
|
router *routerManager
|
||||||
inputDefaultRuleSpecs []string
|
|
||||||
outputDefaultRuleSpecs []string
|
|
||||||
wgIface iFaceMapper
|
|
||||||
|
|
||||||
rulesets map[string]ruleset
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@ -47,47 +31,29 @@ type iFaceMapper interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ruleset struct {
|
|
||||||
rule *Rule
|
|
||||||
ips map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
wgIface: wgIface,
|
|
||||||
inputDefaultRuleSpecs: []string{
|
|
||||||
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
|
|
||||||
outputDefaultRuleSpecs: []string{
|
|
||||||
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
|
|
||||||
rulesets: make(map[string]ruleset),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ipset.Init()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// init clients for booth ipv4 and ipv6
|
|
||||||
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
if ipv6Supported {
|
m := &Manager{
|
||||||
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
wgIface: wgIface,
|
||||||
if err != nil {
|
ipv4Client: iptablesClient,
|
||||||
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.ipv4Client == nil && m.ipv6Client == nil {
|
m.router, err = newRouterManager(context, iptablesClient)
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
|
if err != nil {
|
||||||
|
log.Debugf("failed to initialize route related chains: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to initialize ACL manager: %s", err)
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
|
||||||
}
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,159 +62,44 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
|||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol fw.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *fw.Port,
|
sPort *firewall.Port,
|
||||||
dPort *fw.Port,
|
dPort *firewall.Port,
|
||||||
direction fw.RuleDirection,
|
direction firewall.RuleDirection,
|
||||||
action fw.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
client, err := m.client(ip)
|
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var dPortVal, sPortVal string
|
|
||||||
if dPort != nil && dPort.Values != nil {
|
|
||||||
// TODO: we support only one port per rule in current implementation of ACLs
|
|
||||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
|
||||||
}
|
|
||||||
if sPort != nil && sPort.Values != nil {
|
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
|
||||||
}
|
|
||||||
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
|
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
|
||||||
|
|
||||||
if ipsetName != "" {
|
|
||||||
rs, rsExists := m.rulesets[ipsetName]
|
|
||||||
if !rsExists {
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
|
||||||
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
if err := ipset.Create(ipsetName); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rsExists {
|
|
||||||
// if ruleset already exists it means we already have the firewall rule
|
|
||||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
|
||||||
rs.ips[ip.String()] = ruleID
|
|
||||||
return &Rule{
|
|
||||||
ruleID: ruleID,
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
dst: direction == fw.RuleDirectionOUT,
|
|
||||||
v6: ip.To4() == nil,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
// this is new ipset so we need to create firewall rule for it
|
|
||||||
}
|
|
||||||
|
|
||||||
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
|
||||||
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check is output rule already exists: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("input rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check is input rule already exists: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("input rule already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
|
||||||
ruleID: ruleID,
|
|
||||||
specs: specs,
|
|
||||||
ipsetName: ipsetName,
|
|
||||||
ip: ip.String(),
|
|
||||||
dst: direction == fw.RuleDirectionOUT,
|
|
||||||
v6: ip.To4() == nil,
|
|
||||||
}
|
|
||||||
if ipsetName != "" {
|
|
||||||
// ipset name is defined and it means that this rule was created
|
|
||||||
// for it, need to associate it with ruleset
|
|
||||||
m.rulesets[ipsetName] = ruleset{
|
|
||||||
rule: rule,
|
|
||||||
ips: map[string]string{rule.ip: ruleID},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return rule, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, ok := rule.(*Rule)
|
return m.aclMgr.DeleteRule(rule)
|
||||||
if !ok {
|
}
|
||||||
return fmt.Errorf("invalid rule type")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := m.ipv4Client
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
if r.v6 {
|
return true
|
||||||
if m.ipv6Client == nil {
|
}
|
||||||
return fmt.Errorf("ipv6 is not supported")
|
|
||||||
}
|
|
||||||
client = m.ipv6Client
|
|
||||||
}
|
|
||||||
|
|
||||||
if rs, ok := m.rulesets[r.ipsetName]; ok {
|
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||||
// delete IP from ruleset IPs list and ipset
|
m.mutex.Lock()
|
||||||
if _, ok := rs.ips[r.ip]; ok {
|
defer m.mutex.Unlock()
|
||||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
|
||||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
|
||||||
}
|
|
||||||
delete(rs.ips, r.ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if after delete, set still contains other IPs,
|
return m.router.InsertRoutingRules(pair)
|
||||||
// no need to delete firewall rule and we should exit here
|
}
|
||||||
if len(rs.ips) != 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// we delete last IP from the set, that means we need to delete
|
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||||
// set itself and associated firewall rule too
|
m.mutex.Lock()
|
||||||
delete(m.rulesets, r.ipsetName)
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
return m.router.RemoveRoutingRules(pair)
|
||||||
log.Errorf("delete empty ipset: %v", err)
|
|
||||||
}
|
|
||||||
r = rs.rule
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.dst {
|
|
||||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
|
||||||
}
|
|
||||||
return client.Delete("filter", ChainInputFilterName, r.specs...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@ -256,223 +107,49 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if err := m.reset(m.ipv4Client, "filter"); err != nil {
|
errAcl := m.aclMgr.Reset()
|
||||||
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err)
|
if errAcl != nil {
|
||||||
|
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
||||||
}
|
}
|
||||||
if m.ipv6Client != nil {
|
errMgr := m.router.Reset()
|
||||||
if err := m.reset(m.ipv6Client, "filter"); err != nil {
|
if errMgr != nil {
|
||||||
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err)
|
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
||||||
}
|
return errMgr
|
||||||
}
|
}
|
||||||
|
return errAcl
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
if m.wgIface.IsUserspaceBind() {
|
if !m.wgIface.IsUserspaceBind() {
|
||||||
_, err := m.AddFiltering(
|
return nil
|
||||||
net.ParseIP("0.0.0.0"),
|
|
||||||
"all",
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionAccept,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
|
||||||
}
|
|
||||||
_, err = m.AddFiltering(
|
|
||||||
net.ParseIP("0.0.0.0"),
|
|
||||||
"all",
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
fw.RuleDirectionOUT,
|
|
||||||
fw.ActionAccept,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
_, err := m.AddFiltering(
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
"all",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
firewall.RuleDirectionIN,
|
||||||
|
firewall.ActionAccept,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||||
|
}
|
||||||
|
_, err = m.AddFiltering(
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
"all",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
firewall.RuleDirectionOUT,
|
||||||
|
firewall.ActionAccept,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// reset firewall chain, clear it and drop it
|
|
||||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
|
||||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to check if input chain exists: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
|
||||||
return err
|
|
||||||
} else if ok {
|
|
||||||
if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
|
||||||
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = client.ChainExists(table, ChainOutputFilterName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to check if output chain exists: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
|
||||||
return err
|
|
||||||
} else if ok {
|
|
||||||
if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
|
||||||
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
|
|
||||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
|
|
||||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for ipsetName := range m.rulesets {
|
|
||||||
if err := ipset.Flush(ipsetName); err != nil {
|
|
||||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
if err := ipset.Destroy(ipsetName); err != nil {
|
|
||||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
|
||||||
}
|
|
||||||
delete(m.rulesets, ipsetName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
|
||||||
func (m *Manager) filterRuleSpecs(
|
|
||||||
ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string,
|
|
||||||
) (specs []string) {
|
|
||||||
matchByIP := true
|
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
|
||||||
if s := ip.String(); s == "0.0.0.0" || s == "::" {
|
|
||||||
matchByIP = false
|
|
||||||
}
|
|
||||||
switch direction {
|
|
||||||
case fw.RuleDirectionIN:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-s", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case fw.RuleDirectionOUT:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-d", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if protocol != "all" {
|
|
||||||
specs = append(specs, "-p", protocol)
|
|
||||||
}
|
|
||||||
if sPort != "" {
|
|
||||||
specs = append(specs, "--sport", sPort)
|
|
||||||
}
|
|
||||||
if dPort != "" {
|
|
||||||
specs = append(specs, "--dport", dPort)
|
|
||||||
}
|
|
||||||
return append(specs, "-j", m.actionToStr(action))
|
|
||||||
}
|
|
||||||
|
|
||||||
// rawClient returns corresponding iptables client for the given ip
|
|
||||||
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
return m.ipv4Client, nil
|
|
||||||
}
|
|
||||||
if m.ipv6Client == nil {
|
|
||||||
return nil, fmt.Errorf("ipv6 is not supported")
|
|
||||||
}
|
|
||||||
return m.ipv6Client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// client returns client with initialized chain and default rules
|
|
||||||
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
|
||||||
client, err := m.rawClient(ip)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := client.ChainExists("filter", ChainInputFilterName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create input chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = client.ChainExists("filter", ChainOutputFilterName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create output chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) actionToStr(action fw.Action) string {
|
|
||||||
if action == fw.ActionAccept {
|
|
||||||
return "ACCEPT"
|
|
||||||
}
|
|
||||||
return "DROP"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
|
|
||||||
switch {
|
|
||||||
case ipsetName == "":
|
|
||||||
return ""
|
|
||||||
case sPort != "" && dPort != "":
|
|
||||||
return ipsetName + "-sport-dport"
|
|
||||||
case sPort != "":
|
|
||||||
return ipsetName + "-sport"
|
|
||||||
case dPort != "":
|
|
||||||
return ipsetName + "-dport"
|
|
||||||
default:
|
|
||||||
return ipsetName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
@ -9,7 +10,7 @@ import (
|
|||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -55,7 +56,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock, true)
|
manager, err := Create(context.Background(), mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -67,17 +68,20 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 fw.Rule
|
var rule1 []fw.Rule
|
||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
for _, r := range rule1 {
|
||||||
|
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||||
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
var rule2 fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
@ -87,21 +91,28 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
for _, r := range rule2 {
|
||||||
|
rr := r.(*Rule)
|
||||||
|
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
err := manager.DeleteRule(rule1)
|
for _, r := range rule1 {
|
||||||
require.NoError(t, err, "failed to delete rule")
|
err := manager.DeleteRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
err := manager.DeleteRule(rule2)
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to delete rule")
|
err := manager.DeleteRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
}
|
||||||
|
|
||||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
@ -114,11 +125,11 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName)
|
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||||
require.NoError(t, err, "failed check chain exists")
|
require.NoError(t, err, "failed check chain exists")
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName)
|
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -143,7 +154,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock, true)
|
manager, err := Create(context.Background(), mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -155,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 fw.Rule
|
var rule1 []fw.Rule
|
||||||
t.Run("add first rule with set", func(t *testing.T) {
|
t.Run("add first rule with set", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
@ -165,12 +176,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
for _, r := range rule1 {
|
||||||
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
||||||
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
||||||
|
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
var rule2 fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
@ -180,23 +193,29 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
||||||
"default", "accept HTTPS traffic from ports range",
|
"default", "accept HTTPS traffic from ports range",
|
||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add rule")
|
for _, r := range rule2 {
|
||||||
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
|
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
t.Run("delete first rule", func(t *testing.T) {
|
||||||
err := manager.DeleteRule(rule1)
|
for _, r := range rule1 {
|
||||||
require.NoError(t, err, "failed to delete rule")
|
err := manager.DeleteRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
err := manager.DeleteRule(rule2)
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to delete rule")
|
err := manager.DeleteRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
|
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
@ -206,7 +225,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||||
require.NoError(t, err, "failed to check rule")
|
require.NoError(t, err, "failed to check rule")
|
||||||
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
||||||
@ -232,7 +251,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock, true)
|
manager, err := Create(context.Background(), mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@ -243,7 +262,6 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = manager.client(net.ParseIP("10.20.0.100"))
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
|
340
client/firewall/iptables/router_linux.go
Normal file
340
client/firewall/iptables/router_linux.go
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Ipv4Forwarding = "netbird-rt-forwarding"
|
||||||
|
ipv4Nat = "netbird-rt-nat"
|
||||||
|
)
|
||||||
|
|
||||||
|
// constants needed to manage and create iptable rules
|
||||||
|
const (
|
||||||
|
tableFilter = "filter"
|
||||||
|
tableNat = "nat"
|
||||||
|
chainFORWARD = "FORWARD"
|
||||||
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
|
chainRTFWD = "NETBIRD-RT-FWD"
|
||||||
|
routingFinalForwardJump = "ACCEPT"
|
||||||
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
)
|
||||||
|
|
||||||
|
type routerManager struct {
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
iptablesClient *iptables.IPTables
|
||||||
|
rules map[string][]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
m := &routerManager{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
iptablesClient: iptablesClient,
|
||||||
|
rules: make(map[string][]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.cleanUpDefaultForwardRules()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to cleanup routing rules: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = m.createContainers()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create containers for route: %s", err)
|
||||||
|
}
|
||||||
|
return m, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
||||||
|
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||||
|
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pair.Masquerade {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertRoutingRule inserts an iptable rule
|
||||||
|
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||||
|
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
|
||||||
|
existingRule, found := i.rules[ruleKey]
|
||||||
|
if found {
|
||||||
|
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||||
|
}
|
||||||
|
delete(i.rules, ruleKey)
|
||||||
|
}
|
||||||
|
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
i.rules[ruleKey] = rule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
||||||
|
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||||
|
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pair.Masquerade {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||||
|
existingRule, found := i.rules[ruleKey]
|
||||||
|
if found {
|
||||||
|
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(i.rules, ruleKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *routerManager) RouteingFwChainName() string {
|
||||||
|
return chainRTFWD
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *routerManager) Reset() error {
|
||||||
|
err := i.cleanUpDefaultForwardRules()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
i.rules = make(map[string][]string)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *routerManager) cleanUpDefaultForwardRules() error {
|
||||||
|
err := i.cleanJumpRules()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("flushing routing related tables")
|
||||||
|
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
|
||||||
|
return err
|
||||||
|
} else if ok {
|
||||||
|
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
|
||||||
|
return err
|
||||||
|
} else if ok {
|
||||||
|
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *routerManager) createContainers() error {
|
||||||
|
if i.rules[Ipv4Forwarding] != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
errMSGFormat := "failed creating chain %s,error: %v"
|
||||||
|
err := i.createChain(tableFilter, chainRTFWD)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.createChain(tableNat, chainRTNAT)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.addJumpRules()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while creating jump rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addJumpRules create jump rules to send packets to NetBird chains
|
||||||
|
func (i *routerManager) addJumpRules() error {
|
||||||
|
rule := []string{"-j", chainRTFWD}
|
||||||
|
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
i.rules[Ipv4Forwarding] = rule
|
||||||
|
|
||||||
|
rule = []string{"-j", chainRTNAT}
|
||||||
|
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
i.rules[ipv4Nat] = rule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
||||||
|
func (i *routerManager) cleanJumpRules() error {
|
||||||
|
var err error
|
||||||
|
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
|
||||||
|
rule, found := i.rules[Ipv4Forwarding]
|
||||||
|
if found {
|
||||||
|
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rule, found = i.rules[ipv4Nat]
|
||||||
|
if found {
|
||||||
|
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list rules: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ruleString := range rules {
|
||||||
|
if !strings.Contains(ruleString, "NETBIRD") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rule := strings.Fields(ruleString)
|
||||||
|
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ruleString := range rules {
|
||||||
|
if !strings.Contains(ruleString, "NETBIRD") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rule := strings.Fields(ruleString)
|
||||||
|
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *routerManager) createChain(table, newChain string) error {
|
||||||
|
chains, err := i.iptablesClient.ListChains(table)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldCreateChain := true
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain == newChain {
|
||||||
|
shouldCreateChain = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldCreateChain {
|
||||||
|
err = i.iptablesClient.NewChain(table, newChain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// genRuleSpec generates rule specification with comment identifier
|
||||||
|
func genRuleSpec(jump, id, source, destination string) []string {
|
||||||
|
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIptablesRuleType(table string) string {
|
||||||
|
ruleType := "forwarding"
|
||||||
|
if table == tableNat {
|
||||||
|
ruleType = "nat"
|
||||||
|
}
|
||||||
|
return ruleType
|
||||||
|
}
|
229
client/firewall/iptables/router_linux_test.go
Normal file
229
client/firewall/iptables/router_linux_test.go
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isIptablesSupported() bool {
|
||||||
|
_, err4 := exec.LookPath("iptables")
|
||||||
|
return err4 == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||||
|
if !isIptablesSupported() {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
|
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||||
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = manager.Reset()
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||||
|
|
||||||
|
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
|
||||||
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
|
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
|
require.True(t, exists, "postrouting rule should exist")
|
||||||
|
|
||||||
|
pair := firewall.RouterPair{
|
||||||
|
ID: "abc",
|
||||||
|
Source: "100.100.100.1/32",
|
||||||
|
Destination: "100.100.100.0/24",
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
|
||||||
|
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
|
||||||
|
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
err = manager.Reset()
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
||||||
|
|
||||||
|
if !isIptablesSupported() {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
|
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := manager.Reset()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to reset iptables manager: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "forwarding pair should be inserted")
|
||||||
|
|
||||||
|
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||||
|
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||||
|
require.True(t, exists, "forwarding rule should exist")
|
||||||
|
|
||||||
|
foundRule, found := manager.rules[forwardRuleKey]
|
||||||
|
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||||
|
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||||
|
|
||||||
|
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||||
|
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||||
|
require.True(t, exists, "income forwarding rule should exist")
|
||||||
|
|
||||||
|
foundRule, found = manager.rules[inForwardRuleKey]
|
||||||
|
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||||
|
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||||
|
|
||||||
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||||
|
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
|
if testCase.InputPair.Masquerade {
|
||||||
|
require.True(t, exists, "nat rule should be created")
|
||||||
|
foundNatRule, foundNat := manager.rules[natRuleKey]
|
||||||
|
require.True(t, foundNat, "nat rule should exist in the map")
|
||||||
|
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
||||||
|
} else {
|
||||||
|
require.False(t, exists, "nat rule should not be created")
|
||||||
|
_, foundNat := manager.rules[natRuleKey]
|
||||||
|
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||||
|
}
|
||||||
|
|
||||||
|
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||||
|
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
|
if testCase.InputPair.Masquerade {
|
||||||
|
require.True(t, exists, "income nat rule should be created")
|
||||||
|
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
||||||
|
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||||
|
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||||
|
} else {
|
||||||
|
require.False(t, exists, "nat rule should not be created")
|
||||||
|
_, foundNat := manager.rules[inNatRuleKey]
|
||||||
|
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
||||||
|
|
||||||
|
if !isIptablesSupported() {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range test.RemoveRuleTestCases {
|
||||||
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
|
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
|
||||||
|
manager, err := newRouterManager(context.TODO(), iptablesClient)
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
defer func() {
|
||||||
|
_ = manager.Reset()
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||||
|
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||||
|
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||||
|
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||||
|
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
err = manager.Reset()
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||||
|
require.False(t, exists, "forwarding rule should not exist")
|
||||||
|
|
||||||
|
_, found := manager.rules[forwardRuleKey]
|
||||||
|
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
|
||||||
|
require.False(t, exists, "income forwarding rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[inForwardRuleKey]
|
||||||
|
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
|
require.False(t, exists, "nat rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[natRuleKey]
|
||||||
|
require.False(t, found, "nat rule should exist in the manager map")
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||||
|
require.False(t, exists, "income nat rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[inNatRuleKey]
|
||||||
|
require.False(t, found, "income nat rule should exist in the manager map")
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -7,8 +7,7 @@ type Rule struct {
|
|||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
ip string
|
ip string
|
||||||
dst bool
|
chain string
|
||||||
v6 bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
|
50
client/firewall/iptables/rulestore_linux.go
Normal file
50
client/firewall/iptables/rulestore_linux.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
type ipList struct {
|
||||||
|
ips map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIpList(ip string) ipList {
|
||||||
|
ips := make(map[string]struct{})
|
||||||
|
ips[ip] = struct{}{}
|
||||||
|
|
||||||
|
return ipList{
|
||||||
|
ips: ips,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipList) addIP(ip string) {
|
||||||
|
s.ips[ip] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipsetStore struct {
|
||||||
|
ipsets map[string]ipList // ipsetName -> ruleset
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIpsetStore() *ipsetStore {
|
||||||
|
return &ipsetStore{
|
||||||
|
ipsets: make(map[string]ipList),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
||||||
|
r, ok := s.ipsets[ipsetName]
|
||||||
|
return r, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
||||||
|
s.ipsets[ipsetName] = list
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||||
|
s.ipsets[ipsetName] = ipList{}
|
||||||
|
delete(s.ipsets, ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) ipsetNames() []string {
|
||||||
|
names := make([]string, 0, len(s.ipsets))
|
||||||
|
for name := range s.ipsets {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
@ -1,9 +1,17 @@
|
|||||||
package firewall
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NatFormat = "netbird-nat-%s"
|
||||||
|
ForwardingFormat = "netbird-fwd-%s"
|
||||||
|
InNatFormat = "netbird-nat-in-%s"
|
||||||
|
InForwardingFormat = "netbird-fwd-in-%s"
|
||||||
|
)
|
||||||
|
|
||||||
// Rule abstraction should be implemented by each firewall manager
|
// Rule abstraction should be implemented by each firewall manager
|
||||||
//
|
//
|
||||||
// Each firewall type for different OS can use different type
|
// Each firewall type for different OS can use different type
|
||||||
@ -27,10 +35,8 @@ const (
|
|||||||
type Action int
|
type Action int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ActionUnknown is a unknown action
|
|
||||||
ActionUnknown Action = iota
|
|
||||||
// ActionAccept is the action to accept a packet
|
// ActionAccept is the action to accept a packet
|
||||||
ActionAccept
|
ActionAccept Action = iota
|
||||||
// ActionDrop is the action to drop a packet
|
// ActionDrop is the action to drop a packet
|
||||||
ActionDrop
|
ActionDrop
|
||||||
)
|
)
|
||||||
@ -56,16 +62,27 @@ type Manager interface {
|
|||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (Rule, error)
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
DeleteRule(rule Rule) error
|
DeleteRule(rule Rule) error
|
||||||
|
|
||||||
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
|
// InsertRoutingRules inserts a routing firewall rule
|
||||||
|
InsertRoutingRules(pair RouterPair) error
|
||||||
|
|
||||||
|
// RemoveRoutingRules removes a routing firewall rule
|
||||||
|
RemoveRoutingRules(pair RouterPair) error
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset() error
|
||||||
|
|
||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
|
}
|
||||||
// TODO: migrate routemanager firewal actions to this interface
|
|
||||||
|
func GenKey(format string, input string) string {
|
||||||
|
return fmt.Sprintf(format, input)
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package firewall
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
18
client/firewall/manager/routerpair.go
Normal file
18
client/firewall/manager/routerpair.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
type RouterPair struct {
|
||||||
|
ID string
|
||||||
|
Source string
|
||||||
|
Destination string
|
||||||
|
Masquerade bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetInPair(pair RouterPair) RouterPair {
|
||||||
|
return RouterPair{
|
||||||
|
ID: pair.ID,
|
||||||
|
// invert Source/Destination
|
||||||
|
Source: pair.Destination,
|
||||||
|
Destination: pair.Source,
|
||||||
|
Masquerade: pair.Masquerade,
|
||||||
|
}
|
||||||
|
}
|
1121
client/firewall/nftables/acl_linux.go
Normal file
1121
client/firewall/nftables/acl_linux.go
Normal file
File diff suppressed because it is too large
Load Diff
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ipsetStore struct {
|
||||||
|
ipsetReference map[string]int
|
||||||
|
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIpsetStore() *ipsetStore {
|
||||||
|
return &ipsetStore{
|
||||||
|
ipsetReference: make(map[string]int),
|
||||||
|
ipsets: make(map[string]map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
||||||
|
r, ok := s.ipsets[ipsetName]
|
||||||
|
return r, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
||||||
|
s.ipsetReference[ipsetName] = 0
|
||||||
|
ipList := make(map[string]struct{})
|
||||||
|
s.ipsets[ipsetName] = ipList
|
||||||
|
return ipList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||||
|
delete(s.ipsetReference, ipsetName)
|
||||||
|
delete(s.ipsets, ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
||||||
|
ipList, ok := s.ipsets[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(ipList, ip.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
||||||
|
ipList, ok := s.ipsets[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipList[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
||||||
|
ipList, ok := s.ipsets[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok = ipList[ip.String()]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
||||||
|
s.ipsetReference[ipsetName]++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
||||||
|
r, ok := s.ipsetReference[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.ipsetReference[ipsetName]--
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
||||||
|
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.ipsetReference[ipsetName] == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
@ -2,90 +2,52 @@ package nftables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// FilterTableName is the name of the table that is used for filtering by the Netbird client
|
// tableName is the name of the table that is used for filtering by the Netbird client
|
||||||
FilterTableName = "netbird-acl"
|
tableName = "netbird"
|
||||||
|
|
||||||
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
|
|
||||||
FilterInputChainName = "netbird-acl-input-filter"
|
|
||||||
|
|
||||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
|
||||||
FilterOutputChainName = "netbird-acl-output-filter"
|
|
||||||
|
|
||||||
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
rConn *nftables.Conn
|
||||||
rConn *nftables.Conn
|
|
||||||
sConn *nftables.Conn
|
|
||||||
tableIPv4 *nftables.Table
|
|
||||||
tableIPv6 *nftables.Table
|
|
||||||
|
|
||||||
filterInputChainIPv4 *nftables.Chain
|
|
||||||
filterOutputChainIPv4 *nftables.Chain
|
|
||||||
|
|
||||||
filterInputChainIPv6 *nftables.Chain
|
|
||||||
filterOutputChainIPv6 *nftables.Chain
|
|
||||||
|
|
||||||
rulesetManager *rulesetManager
|
|
||||||
setRemovedIPs map[string]struct{}
|
|
||||||
setRemoved map[string]*nftables.Set
|
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
}
|
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
router *router
|
||||||
type iFaceMapper interface {
|
aclManager *AclManager
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
||||||
// sConn is used for creating sets and adding/removing elements from them
|
m := &Manager{
|
||||||
// it's differ then rConn (which does create new conn for each flush operation)
|
rConn: &nftables.Conn{},
|
||||||
// and is permanent. Using same connection for booth type of operations
|
wgIface: wgIface,
|
||||||
// overloads netlink with high amount of rules ( > 10000)
|
}
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
|
||||||
|
workTable, err := m.createWorkTable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &Manager{
|
m.router, err = newRouter(context, workTable)
|
||||||
rConn: &nftables.Conn{},
|
if err != nil {
|
||||||
sConn: sConn,
|
return nil, err
|
||||||
|
|
||||||
rulesetManager: newRuleManager(),
|
|
||||||
setRemovedIPs: map[string]struct{}{},
|
|
||||||
setRemoved: map[string]*nftables.Set{},
|
|
||||||
|
|
||||||
wgIface: wgIface,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,649 +60,58 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto fw.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *fw.Port,
|
sPort *firewall.Port,
|
||||||
dPort *fw.Port,
|
dPort *firewall.Port,
|
||||||
direction fw.RuleDirection,
|
direction firewall.RuleDirection,
|
||||||
action fw.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
ipset *nftables.Set
|
|
||||||
table *nftables.Table
|
|
||||||
chain *nftables.Chain
|
|
||||||
)
|
|
||||||
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
table, chain, err = m.chain(
|
|
||||||
ip,
|
|
||||||
FilterOutputChainName,
|
|
||||||
nftables.ChainHookOutput,
|
|
||||||
nftables.ChainPriorityFilter,
|
|
||||||
nftables.ChainTypeFilter)
|
|
||||||
} else {
|
|
||||||
table, chain, err = m.chain(
|
|
||||||
ip,
|
|
||||||
FilterInputChainName,
|
|
||||||
nftables.ChainHookInput,
|
|
||||||
nftables.ChainPriorityFilter,
|
|
||||||
nftables.ChainTypeFilter)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rawIP := ip.To4()
|
rawIP := ip.To4()
|
||||||
if rawIP == nil {
|
if rawIP == nil {
|
||||||
rawIP = ip.To16()
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
||||||
|
|
||||||
if ipsetName != "" {
|
|
||||||
// if we already have set with given name, just add ip to the set
|
|
||||||
// and return rule with new ID in other case let's create rule
|
|
||||||
// with fresh created set and set element
|
|
||||||
|
|
||||||
var isSetNew bool
|
|
||||||
ipset, err = m.rConn.GetSetByName(table, ipsetName)
|
|
||||||
if err != nil {
|
|
||||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
|
||||||
return nil, fmt.Errorf("get set name: %v", err)
|
|
||||||
}
|
|
||||||
isSetNew = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
|
||||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
|
||||||
}
|
|
||||||
if err := m.sConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isSetNew {
|
|
||||||
// if we already have nftables rules with set for given direction
|
|
||||||
// just add new rule to the ruleset and return new fw.Rule object
|
|
||||||
|
|
||||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
|
||||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
|
||||||
}
|
|
||||||
// if ipset exists but it is not linked to rule for given direction
|
|
||||||
// create new rule for direction and bind ipset to it later
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
|
||||||
}
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if proto != "all" {
|
|
||||||
expressions = append(expressions, &expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(9),
|
|
||||||
Len: uint32(1),
|
|
||||||
})
|
|
||||||
|
|
||||||
var protoData []byte
|
|
||||||
switch proto {
|
|
||||||
case fw.ProtocolTCP:
|
|
||||||
protoData = []byte{unix.IPPROTO_TCP}
|
|
||||||
case fw.ProtocolUDP:
|
|
||||||
protoData = []byte{unix.IPPROTO_UDP}
|
|
||||||
case fw.ProtocolICMP:
|
|
||||||
protoData = []byte{unix.IPPROTO_ICMP}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
expressions = append(expressions, &expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Data: protoData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
|
||||||
// in that case not add IP match expression into the rule definition
|
|
||||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
|
||||||
// source address position
|
|
||||||
addrLen := uint32(len(rawIP))
|
|
||||||
addrOffset := uint32(12)
|
|
||||||
if addrLen == 16 {
|
|
||||||
addrOffset = 8
|
|
||||||
}
|
|
||||||
|
|
||||||
// change to destination address position if need
|
|
||||||
if direction == fw.RuleDirectionOUT {
|
|
||||||
addrOffset += addrLen
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: addrOffset,
|
|
||||||
Len: addrLen,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
// add individual IP for match if no ipset defined
|
|
||||||
if ipset == nil {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: rawIP,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ipsetName,
|
|
||||||
SetID: ipset.ID,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 0,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*sPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) != 0 {
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*dPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == fw.ActionAccept {
|
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
|
||||||
} else {
|
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
|
||||||
|
|
||||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: chain,
|
|
||||||
Position: 0,
|
|
||||||
Exprs: expressions,
|
|
||||||
UserData: userData,
|
|
||||||
})
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
|
||||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRulesetID returns ruleset ID based on given parameters
|
|
||||||
func (m *Manager) getRulesetID(
|
|
||||||
ip net.IP,
|
|
||||||
proto fw.Protocol,
|
|
||||||
sPort *fw.Port,
|
|
||||||
dPort *fw.Port,
|
|
||||||
direction fw.RuleDirection,
|
|
||||||
action fw.Action,
|
|
||||||
ipsetName string,
|
|
||||||
) string {
|
|
||||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
|
||||||
if sPort != nil {
|
|
||||||
rulesetID += sPort.String()
|
|
||||||
}
|
|
||||||
rulesetID += ":"
|
|
||||||
if dPort != nil {
|
|
||||||
rulesetID += dPort.String()
|
|
||||||
}
|
|
||||||
rulesetID += ":"
|
|
||||||
rulesetID += strconv.Itoa(int(action))
|
|
||||||
if ipsetName == "" {
|
|
||||||
return "ip:" + ip.String() + rulesetID
|
|
||||||
}
|
|
||||||
return "set:" + ipsetName + rulesetID
|
|
||||||
}
|
|
||||||
|
|
||||||
// createSet in given table by name
|
|
||||||
func (m *Manager) createSet(
|
|
||||||
table *nftables.Table,
|
|
||||||
rawIP []byte,
|
|
||||||
name string,
|
|
||||||
) (*nftables.Set, error) {
|
|
||||||
keyType := nftables.TypeIPAddr
|
|
||||||
if len(rawIP) == 16 {
|
|
||||||
keyType = nftables.TypeIP6Addr
|
|
||||||
}
|
|
||||||
// else we create new ipset and continue creating rule
|
|
||||||
ipset := &nftables.Set{
|
|
||||||
Name: name,
|
|
||||||
Table: table,
|
|
||||||
Dynamic: true,
|
|
||||||
KeyType: keyType,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
|
||||||
return nil, fmt.Errorf("create set: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush created set: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipset, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// chain returns the chain for the given IP address with specific settings
|
|
||||||
func (m *Manager) chain(
|
|
||||||
ip net.IP,
|
|
||||||
name string,
|
|
||||||
hook nftables.ChainHook,
|
|
||||||
priority nftables.ChainPriority,
|
|
||||||
cType nftables.ChainType,
|
|
||||||
) (*nftables.Table, *nftables.Chain, error) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
|
|
||||||
if c != nil {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip.To4() != nil {
|
|
||||||
if name == FilterInputChainName {
|
|
||||||
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
|
|
||||||
return m.tableIPv4, m.filterInputChainIPv4, err
|
|
||||||
}
|
|
||||||
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
|
|
||||||
return m.tableIPv4, m.filterOutputChainIPv4, err
|
|
||||||
}
|
|
||||||
if name == FilterInputChainName {
|
|
||||||
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
|
|
||||||
return m.tableIPv4, m.filterInputChainIPv6, err
|
|
||||||
}
|
|
||||||
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
|
|
||||||
return m.tableIPv4, m.filterOutputChainIPv6, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// table returns the table for the given family of the IP address
|
|
||||||
func (m *Manager) table(
|
|
||||||
family nftables.TableFamily, tableName string,
|
|
||||||
) (*nftables.Table, error) {
|
|
||||||
// we cache access to Netbird ACL table only
|
|
||||||
if tableName != FilterTableName {
|
|
||||||
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if family == nftables.TableFamilyIPv4 {
|
|
||||||
if m.tableIPv4 != nil {
|
|
||||||
return m.tableIPv4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.tableIPv4 = table
|
|
||||||
return m.tableIPv4, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tableIPv6 != nil {
|
|
||||||
return m.tableIPv6, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.tableIPv6 = table
|
|
||||||
return m.tableIPv6, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) createTableIfNotExists(
|
|
||||||
family nftables.TableFamily, tableName string,
|
|
||||||
) (*nftables.Table, error) {
|
|
||||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == tableName {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return table, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) createChainIfNotExists(
|
|
||||||
family nftables.TableFamily,
|
|
||||||
tableName string,
|
|
||||||
name string,
|
|
||||||
hooknum nftables.ChainHook,
|
|
||||||
priority nftables.ChainPriority,
|
|
||||||
chainType nftables.ChainType,
|
|
||||||
) (*nftables.Chain, error) {
|
|
||||||
table, err := m.table(family, tableName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list of chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range chains {
|
|
||||||
if c.Name == name && c.Table.Name == table.Name {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
polAccept := nftables.ChainPolicyAccept
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: name,
|
|
||||||
Table: table,
|
|
||||||
Hooknum: hooknum,
|
|
||||||
Priority: priority,
|
|
||||||
Type: chainType,
|
|
||||||
Policy: &polAccept,
|
|
||||||
}
|
|
||||||
|
|
||||||
chain = m.rConn.AddChain(chain)
|
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
|
||||||
shiftDSTAddr := 0
|
|
||||||
if name == FilterOutputChainName {
|
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
|
||||||
shiftDSTAddr = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
|
|
||||||
if m.wgIface.Address().IP.To4() == nil {
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(8 + (16 * shiftDSTAddr)),
|
|
||||||
Len: 16,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 16,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: mask.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: uint32(12 + (4 * shiftDSTAddr)),
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
|
|
||||||
expressions = []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
|
||||||
}
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chain, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
nativeRule, ok := rule.(*Rule)
|
return m.aclManager.DeleteRule(rule)
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("invalid rule type")
|
|
||||||
}
|
|
||||||
|
|
||||||
if nativeRule.nftRule == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if nativeRule.nftSet != nil {
|
|
||||||
// call twice of delete set element raises error
|
|
||||||
// so we need to check if element is already removed
|
|
||||||
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
|
||||||
if _, ok := m.setRemovedIPs[key]; !ok {
|
|
||||||
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
|
||||||
}
|
|
||||||
if err := m.sConn.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.setRemovedIPs[key] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.rulesetManager.deleteRule(nativeRule) {
|
|
||||||
// deleteRule indicates that we still have IP in the ruleset
|
|
||||||
// it means we should not remove the nftables rule but need to update set
|
|
||||||
// so we prepare IP to be removed from set on the next flush call
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
|
||||||
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
|
||||||
}
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nativeRule.nftRule = nil
|
|
||||||
|
|
||||||
if nativeRule.nftSet != nil {
|
|
||||||
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
|
||||||
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
|
||||||
}
|
|
||||||
nativeRule.nftSet = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
func (m *Manager) Reset() error {
|
return true
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChains()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
|
||||||
}
|
|
||||||
for _, c := range chains {
|
|
||||||
// delete Netbird allow input traffic rule if it exists
|
|
||||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
|
||||||
rules, err := m.rConn.GetRules(c.Table, c)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, r := range rules {
|
|
||||||
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
|
|
||||||
if err := m.rConn.DelRule(r); err != nil {
|
|
||||||
log.Errorf("delete rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
|
||||||
m.rConn.DelChain(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := m.rConn.ListTables()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
|
||||||
}
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == FilterTableName {
|
|
||||||
m.rConn.DelTable(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.rConn.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||||
//
|
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
|
||||||
func (m *Manager) Flush() error {
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if err := m.flushWithBackoff(); err != nil {
|
return m.router.InsertRoutingRules(pair)
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// set must be removed after flush rule changes
|
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||||
// otherwise we will get error
|
m.mutex.Lock()
|
||||||
for _, s := range m.setRemoved {
|
defer m.mutex.Unlock()
|
||||||
m.rConn.FlushSet(s)
|
|
||||||
m.rConn.DelSet(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.setRemoved) > 0 {
|
return m.router.RemoveRoutingRules(pair)
|
||||||
if err := m.flushWithBackoff(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.setRemovedIPs = map[string]struct{}{}
|
|
||||||
m.setRemoved = map[string]*nftables.Set{}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
// todo review this method usage
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
tf := nftables.TableFamilyIPv4
|
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||||
if m.wgIface.Address().IP.To4() == nil {
|
|
||||||
tf = nftables.TableFamilyIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := m.rConn.ListChainsOfTableFamily(tf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
@ -777,47 +148,75 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) flushWithBackoff() (err error) {
|
// Reset firewall to the default state
|
||||||
backoff := 4
|
func (m *Manager) Reset() error {
|
||||||
backoffTime := 1000 * time.Millisecond
|
m.mutex.Lock()
|
||||||
for i := 0; ; i++ {
|
defer m.mutex.Unlock()
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
chains, err := m.rConn.ListChains()
|
||||||
if !strings.Contains(err.Error(), "busy") {
|
if err != nil {
|
||||||
return
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
|
||||||
log.Error("failed to flush nftables, retrying...")
|
|
||||||
if i == backoff-1 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
time.Sleep(backoffTime)
|
|
||||||
backoffTime *= 2
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
for _, c := range chains {
|
||||||
|
// delete Netbird allow input traffic rule if it exists
|
||||||
|
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||||
|
rules, err := m.rConn.GetRules(c.Table, c)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, r := range rules {
|
||||||
|
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||||
|
if err := m.rConn.DelRule(r); err != nil {
|
||||||
|
log.Errorf("delete rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.router.ResetForwardRules()
|
||||||
|
|
||||||
|
tables, err := m.rConn.ListTables()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list of tables: %w", err)
|
||||||
|
}
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == tableName {
|
||||||
|
m.rConn.DelTable(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
// Flush rule/chain/set operations from the buffer
|
||||||
if table == nil || chain == nil {
|
//
|
||||||
return nil
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
}
|
// todo review this method usage
|
||||||
|
func (m *Manager) Flush() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
list, err := m.rConn.GetRules(table, chain)
|
return m.aclManager.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range list {
|
for _, t := range tables {
|
||||||
if len(rule.UserData) != 0 {
|
if t.Name == tableName {
|
||||||
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
m.rConn.DelTable(t)
|
||||||
log.Errorf("failed to set rule handle: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
return table, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||||
@ -835,7 +234,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
|||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
UserData: []byte(AllowNetbirdInputRuleID),
|
UserData: []byte(allowNetbirdInputRuleID),
|
||||||
}
|
}
|
||||||
_ = m.rConn.InsertRule(rule)
|
_ = m.rConn.InsertRule(rule)
|
||||||
}
|
}
|
||||||
@ -857,15 +256,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port fw.Port) []byte {
|
|
||||||
bs := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
|
||||||
return bs
|
|
||||||
}
|
|
||||||
|
|
||||||
func ifname(n string) []byte {
|
|
||||||
b := make([]byte, 16)
|
|
||||||
copy(b, []byte(n+"\x00"))
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -12,7 +13,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,7 +54,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(context.Background(), mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
@ -82,14 +83,10 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
// test expectations:
|
require.Len(t, rules, 1, "expected 1 rules")
|
||||||
// 1) regular rule
|
|
||||||
// 2) "accept extra routed traffic rule" for the interface
|
|
||||||
// 3) "drop all rule" for the interface
|
|
||||||
require.Len(t, rules, 3, "expected 3 rules")
|
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
add := ipToAdd.Unmap()
|
add := ipToAdd.Unmap()
|
||||||
@ -137,18 +134,17 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||||
|
|
||||||
err = manager.DeleteRule(rule)
|
for _, r := range rule {
|
||||||
require.NoError(t, err, "failed to delete rule")
|
err = manager.DeleteRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
}
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
// test expectations:
|
require.Len(t, rules, 0, "expected 0 rules after deletion")
|
||||||
// 1) "accept extra routed traffic rule" for the interface
|
|
||||||
// 2) "drop all rule" for the interface
|
|
||||||
require.Len(t, rules, 2, "expected 2 rules after deletion")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
@ -173,7 +169,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(context.Background(), mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
|
413
client/firewall/nftables/route_linux.go
Normal file
413
client/firewall/nftables/route_linux.go
Normal file
@ -0,0 +1,413 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
chainNameRouteingFw = "netbird-rt-fwd"
|
||||||
|
chainNameRoutingNat = "netbird-rt-nat"
|
||||||
|
|
||||||
|
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||||
|
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||||
|
)
|
||||||
|
|
||||||
|
// some presets for building nftable rules
|
||||||
|
var (
|
||||||
|
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
||||||
|
|
||||||
|
exprCounterAccept = []expr.Any{
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
type router struct {
|
||||||
|
ctx context.Context
|
||||||
|
stop context.CancelFunc
|
||||||
|
conn *nftables.Conn
|
||||||
|
workTable *nftables.Table
|
||||||
|
filterTable *nftables.Table
|
||||||
|
chains map[string]*nftables.Chain
|
||||||
|
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
|
rules map[string]*nftables.Rule
|
||||||
|
isDefaultFwdRulesEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
|
r := &router{
|
||||||
|
ctx: ctx,
|
||||||
|
stop: cancel,
|
||||||
|
conn: &nftables.Conn{},
|
||||||
|
workTable: workTable,
|
||||||
|
chains: make(map[string]*nftables.Chain),
|
||||||
|
rules: make(map[string]*nftables.Rule),
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
r.filterTable, err = r.loadFilterTable()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errFilterTableNotFound) {
|
||||||
|
log.Warnf("table 'filter' not found for forward rules")
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.cleanUpDefaultForwardRules()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.createContainers()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create containers for route: %s", err)
|
||||||
|
}
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) RouteingFwChainName() string {
|
||||||
|
return chainNameRouteingFw
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetForwardRules cleans existing nftables default forward rules from the system
|
||||||
|
func (r *router) ResetForwardRules() {
|
||||||
|
err := r.cleanUpDefaultForwardRules()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to reset forward rules: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range tables {
|
||||||
|
if table.Name == "filter" {
|
||||||
|
return table, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errFilterTableNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) createContainers() error {
|
||||||
|
|
||||||
|
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRouteingFw,
|
||||||
|
Table: r.workTable,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingNat,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
|
Priority: nftables.ChainPriorityNATSource - 1,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := r.refreshRulesMap()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||||
|
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||||
|
err := r.refreshRulesMap()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.Masquerade {
|
||||||
|
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
|
||||||
|
log.Debugf("add default accept forward rule")
|
||||||
|
r.acceptForwardRule(pair.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||||
|
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||||
|
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||||
|
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||||
|
|
||||||
|
var expression []expr.Any
|
||||||
|
if isNat {
|
||||||
|
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
|
||||||
|
} else {
|
||||||
|
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := manager.GenKey(format, pair.ID)
|
||||||
|
|
||||||
|
_, exists := r.rules[ruleKey]
|
||||||
|
if exists {
|
||||||
|
err := r.removeRoutingRule(format, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainName],
|
||||||
|
Exprs: expression,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) acceptForwardRule(sourceNetwork string) {
|
||||||
|
src := generateCIDRMatcherExpressions(true, sourceNetwork)
|
||||||
|
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
|
||||||
|
|
||||||
|
var exprs []expr.Any
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
rule := &nftables.Rule{
|
||||||
|
Table: r.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: r.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleSrc),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(rule)
|
||||||
|
|
||||||
|
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
|
||||||
|
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
|
||||||
|
|
||||||
|
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
})...)
|
||||||
|
|
||||||
|
rule = &nftables.Rule{
|
||||||
|
Table: r.filterTable,
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: "FORWARD",
|
||||||
|
Table: r.filterTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
},
|
||||||
|
Exprs: exprs,
|
||||||
|
UserData: []byte(userDataAcceptForwardRuleDst),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(rule)
|
||||||
|
r.isDefaultFwdRulesEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
||||||
|
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
|
||||||
|
err := r.refreshRulesMap()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.removeRoutingRule(manager.NatFormat, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.rules) == 0 {
|
||||||
|
err := r.cleanUpDefaultForwardRules()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||||
|
}
|
||||||
|
log.Debugf("nftables: removed rules for %s", pair.Destination)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||||
|
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
|
||||||
|
ruleKey := manager.GenKey(format, pair.ID)
|
||||||
|
|
||||||
|
rule, found := r.rules[ruleKey]
|
||||||
|
if found {
|
||||||
|
ruleType := "forwarding"
|
||||||
|
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||||
|
ruleType = "nat"
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.conn.DelRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
||||||
|
// duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
|
func (r *router) refreshRulesMap() error {
|
||||||
|
for _, chain := range r.chains {
|
||||||
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||||
|
}
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 {
|
||||||
|
r.rules[string(rule.UserData)] = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) cleanUpDefaultForwardRules() error {
|
||||||
|
if r.filterTable == nil {
|
||||||
|
r.isDefaultFwdRulesEnabled = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []*nftables.Rule
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Table.Name != r.filterTable.Name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if chain.Name != "FORWARD" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err = r.conn.GetRules(r.filterTable, chain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
|
||||||
|
err := r.conn.DelRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.isDefaultFwdRulesEnabled = false
|
||||||
|
return r.conn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||||
|
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
|
||||||
|
ip, network, _ := net.ParseCIDR(cidr)
|
||||||
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
|
add := ipToAdd.Unmap()
|
||||||
|
|
||||||
|
var offSet uint32
|
||||||
|
if source {
|
||||||
|
offSet = 12 // src offset
|
||||||
|
} else {
|
||||||
|
offSet = 16 // dst offset
|
||||||
|
}
|
||||||
|
|
||||||
|
return []expr.Any{
|
||||||
|
// fetch src add
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: offSet,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
// net mask
|
||||||
|
&expr.Bitwise{
|
||||||
|
DestRegister: 1,
|
||||||
|
SourceRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: network.Mask,
|
||||||
|
Xor: zeroXor,
|
||||||
|
},
|
||||||
|
// net address
|
||||||
|
&expr.Cmp{
|
||||||
|
Register: 1,
|
||||||
|
Data: add.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
280
client/firewall/nftables/router_linux_test.go
Normal file
280
client/firewall/nftables/router_linux_test.go
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||||
|
UNKNOWN = iota
|
||||||
|
// IPTABLES is the value for the iptables firewall type
|
||||||
|
IPTABLES
|
||||||
|
// NFTABLES is the value for the nftables firewall type
|
||||||
|
NFTABLES
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this OS")
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := createWorkTable()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
|
manager, err := newRouter(context.TODO(), table)
|
||||||
|
require.NoError(t, err, "failed to create router")
|
||||||
|
|
||||||
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
|
defer manager.ResetForwardRules()
|
||||||
|
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||||
|
defer func() {
|
||||||
|
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||||
|
}()
|
||||||
|
require.NoError(t, err, "forwarding pair should be inserted")
|
||||||
|
|
||||||
|
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||||
|
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||||
|
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||||
|
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||||
|
|
||||||
|
found := 0
|
||||||
|
for _, chain := range manager.chains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
|
||||||
|
if testCase.InputPair.Masquerade {
|
||||||
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||||
|
found := 0
|
||||||
|
for _, chain := range manager.chains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||||
|
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
||||||
|
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||||
|
|
||||||
|
found = 0
|
||||||
|
for _, chain := range manager.chains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
|
||||||
|
if testCase.InputPair.Masquerade {
|
||||||
|
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||||
|
found := 0
|
||||||
|
for _, chain := range manager.chains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this OS")
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := createWorkTable()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
for _, testCase := range test.RemoveRuleTestCases {
|
||||||
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
|
manager, err := newRouter(context.TODO(), table)
|
||||||
|
require.NoError(t, err, "failed to create router")
|
||||||
|
|
||||||
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
|
defer manager.ResetForwardRules()
|
||||||
|
|
||||||
|
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||||
|
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||||
|
|
||||||
|
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||||
|
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
|
||||||
|
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: manager.workTable,
|
||||||
|
Chain: manager.chains[chainNameRouteingFw],
|
||||||
|
Exprs: forwardExp,
|
||||||
|
UserData: []byte(forwardRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||||
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
|
||||||
|
|
||||||
|
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: manager.workTable,
|
||||||
|
Chain: manager.chains[chainNameRoutingNat],
|
||||||
|
Exprs: natExp,
|
||||||
|
UserData: []byte(natRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
|
||||||
|
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
|
||||||
|
|
||||||
|
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
||||||
|
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
|
||||||
|
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: manager.workTable,
|
||||||
|
Chain: manager.chains[chainNameRouteingFw],
|
||||||
|
Exprs: forwardExp,
|
||||||
|
UserData: []byte(inForwardRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||||
|
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
|
||||||
|
|
||||||
|
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: manager.workTable,
|
||||||
|
Chain: manager.chains[chainNameRoutingNat],
|
||||||
|
Exprs: natExp,
|
||||||
|
UserData: []byte(inNatRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
err = nftablesTestingClient.Flush()
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
manager.ResetForwardRules()
|
||||||
|
|
||||||
|
err = manager.RemoveRoutingRules(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
|
for _, chain := range manager.chains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 {
|
||||||
|
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||||
|
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||||
|
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||||
|
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||||
|
func check() int {
|
||||||
|
nf := nftables.Conn{}
|
||||||
|
if _, err := nf.ListChains(); err == nil {
|
||||||
|
return NFTABLES
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return UNKNOWN
|
||||||
|
}
|
||||||
|
if isIptablesClientAvailable(ip) {
|
||||||
|
return IPTABLES
|
||||||
|
}
|
||||||
|
|
||||||
|
return UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createWorkTable() (*nftables.Table, error) {
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == tableName {
|
||||||
|
sConn.DelTable(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||||
|
err = sConn.Flush()
|
||||||
|
|
||||||
|
return table, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteWorkTable() {
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.Name == tableName {
|
||||||
|
sConn.DelTable(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,8 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -8,9 +10,8 @@ import (
|
|||||||
type Rule struct {
|
type Rule struct {
|
||||||
nftRule *nftables.Rule
|
nftRule *nftables.Rule
|
||||||
nftSet *nftables.Set
|
nftSet *nftables.Set
|
||||||
|
ruleID string
|
||||||
ruleID string
|
ip net.IP
|
||||||
ip []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
|
@ -1,115 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
|
||||||
type nftRuleset struct {
|
|
||||||
nftRule *nftables.Rule
|
|
||||||
nftSet *nftables.Set
|
|
||||||
issuedRules map[string]*Rule
|
|
||||||
rulesetID string
|
|
||||||
}
|
|
||||||
|
|
||||||
type rulesetManager struct {
|
|
||||||
rulesets map[string]*nftRuleset
|
|
||||||
|
|
||||||
nftSetName2rulesetID map[string]string
|
|
||||||
issuedRuleID2rulesetID map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRuleManager() *rulesetManager {
|
|
||||||
return &rulesetManager{
|
|
||||||
rulesets: map[string]*nftRuleset{},
|
|
||||||
|
|
||||||
nftSetName2rulesetID: map[string]string{},
|
|
||||||
issuedRuleID2rulesetID: map[string]string{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
|
||||||
ruleset, ok := r.rulesets[rulesetID]
|
|
||||||
return ruleset, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rulesetManager) createRuleset(
|
|
||||||
rulesetID string,
|
|
||||||
nftRule *nftables.Rule,
|
|
||||||
nftSet *nftables.Set,
|
|
||||||
) *nftRuleset {
|
|
||||||
ruleset := nftRuleset{
|
|
||||||
rulesetID: rulesetID,
|
|
||||||
nftRule: nftRule,
|
|
||||||
nftSet: nftSet,
|
|
||||||
issuedRules: map[string]*Rule{},
|
|
||||||
}
|
|
||||||
r.rulesets[ruleset.rulesetID] = &ruleset
|
|
||||||
if nftSet != nil {
|
|
||||||
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
|
||||||
}
|
|
||||||
return &ruleset
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rulesetManager) addRule(
|
|
||||||
ruleset *nftRuleset,
|
|
||||||
ip []byte,
|
|
||||||
) (*Rule, error) {
|
|
||||||
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
|
||||||
return nil, fmt.Errorf("ruleset not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := Rule{
|
|
||||||
nftRule: ruleset.nftRule,
|
|
||||||
nftSet: ruleset.nftSet,
|
|
||||||
ruleID: xid.New().String(),
|
|
||||||
ip: ip,
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleset.issuedRules[rule.ruleID] = &rule
|
|
||||||
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
|
||||||
|
|
||||||
return &rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteRule from ruleset and returns true if contains other rules
|
|
||||||
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
|
||||||
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleset := r.rulesets[rulesetID]
|
|
||||||
if ruleset.nftRule == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
|
||||||
delete(ruleset.issuedRules, rule.ruleID)
|
|
||||||
|
|
||||||
if len(ruleset.issuedRules) == 0 {
|
|
||||||
delete(r.rulesets, ruleset.rulesetID)
|
|
||||||
if rule.nftSet != nil {
|
|
||||||
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
|
||||||
//
|
|
||||||
// This is important to do, because after we add rule to the nftables we can't update it until
|
|
||||||
// we set correct handle value to it.
|
|
||||||
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
|
||||||
split := bytes.Split(nftRule.UserData, []byte(" "))
|
|
||||||
ruleset, ok := r.rulesets[string(split[0])]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("ruleset not found")
|
|
||||||
}
|
|
||||||
*ruleset.nftRule = *nftRule
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,122 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRulesetManager_createRuleset(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{
|
|
||||||
UserData: []byte(rulesetID),
|
|
||||||
}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
|
||||||
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
|
||||||
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_addRule(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
|
|
||||||
// Add a rule to the ruleset.
|
|
||||||
ip := []byte("192.168.1.1")
|
|
||||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule, "rule should not be nil")
|
|
||||||
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
|
||||||
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
|
||||||
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
|
||||||
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
|
||||||
|
|
||||||
ruleset2 := &nftRuleset{
|
|
||||||
rulesetID: "ruleset-2",
|
|
||||||
}
|
|
||||||
_, err = rulesetManager.addRule(ruleset2, ip)
|
|
||||||
require.Error(t, err, "addRule() should have failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_deleteRule(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
|
|
||||||
// Add a rule to the ruleset.
|
|
||||||
ip := []byte("192.168.1.1")
|
|
||||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule, "rule should not be nil")
|
|
||||||
|
|
||||||
ip2 := []byte("192.168.1.1")
|
|
||||||
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule2, "rule should not be nil")
|
|
||||||
|
|
||||||
hasNext := rulesetManager.deleteRule(rule)
|
|
||||||
require.True(t, hasNext, "deleteRule() should have returned true")
|
|
||||||
|
|
||||||
// Check that the rule is no longer in the manager.
|
|
||||||
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
|
||||||
|
|
||||||
hasNext = rulesetManager.deleteRule(rule2)
|
|
||||||
require.False(t, hasNext, "deleteRule() should have returned false")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
|
||||||
// Add a rule to the ruleset.
|
|
||||||
ip := []byte("192.168.0.1")
|
|
||||||
|
|
||||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
|
||||||
require.NoError(t, err, "addRule() failed")
|
|
||||||
require.NotNil(t, rule, "rule should not be nil")
|
|
||||||
|
|
||||||
nftRuleCopy := nftRule
|
|
||||||
nftRuleCopy.Handle = 2
|
|
||||||
nftRuleCopy.UserData = []byte(rulesetID)
|
|
||||||
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
|
||||||
require.NoError(t, err, "setNftRuleHandle() failed")
|
|
||||||
// check correct work with references
|
|
||||||
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRulesetManager_getRuleset(t *testing.T) {
|
|
||||||
// Create a ruleset manager.
|
|
||||||
rulesetManager := newRuleManager()
|
|
||||||
// Create a ruleset.
|
|
||||||
rulesetID := "ruleset-1"
|
|
||||||
nftRule := nftables.Rule{}
|
|
||||||
nftSet := nftables.Set{
|
|
||||||
ID: 2,
|
|
||||||
}
|
|
||||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
|
||||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
|
||||||
|
|
||||||
find, ok := rulesetManager.getRuleset(rulesetID)
|
|
||||||
require.True(t, ok, "getRuleset() failed")
|
|
||||||
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
|
||||||
|
|
||||||
_, ok = rulesetManager.getRuleset("does-not-exist")
|
|
||||||
require.False(t, ok, "getRuleset() failed")
|
|
||||||
}
|
|
47
client/firewall/test/cases_linux.go
Normal file
47
client/firewall/test/cases_linux.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package test
|
||||||
|
|
||||||
|
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
|
||||||
|
var (
|
||||||
|
InsertRuleTestCases = []struct {
|
||||||
|
Name string
|
||||||
|
InputPair firewall.RouterPair
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "Insert Forwarding IPV4 Rule",
|
||||||
|
InputPair: firewall.RouterPair{
|
||||||
|
ID: "zxa",
|
||||||
|
Source: "100.100.100.1/32",
|
||||||
|
Destination: "100.100.200.0/24",
|
||||||
|
Masquerade: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||||
|
InputPair: firewall.RouterPair{
|
||||||
|
ID: "zxa",
|
||||||
|
Source: "100.100.100.1/32",
|
||||||
|
Destination: "100.100.200.0/24",
|
||||||
|
Masquerade: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RemoveRuleTestCases = []struct {
|
||||||
|
Name string
|
||||||
|
InputPair firewall.RouterPair
|
||||||
|
IpVersion string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||||
|
InputPair: firewall.RouterPair{
|
||||||
|
ID: "zxa",
|
||||||
|
Source: "100.100.100.1/32",
|
||||||
|
Destination: "100.100.200.0/24",
|
||||||
|
Masquerade: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
@ -1,4 +1,4 @@
|
|||||||
//go:build !windows && !linux
|
//go:build !windows
|
||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
@ -10,10 +10,16 @@ func (m *Manager) Reset() error {
|
|||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
if m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.Reset()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
func (m *Manager) AllowNetbird() error {
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
if m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.AllowNetbird()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,21 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
|
||||||
func (m *Manager) AllowNetbird() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
|
||||||
|
|
||||||
if m.resetHook != nil {
|
|
||||||
return m.resetHook()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
@ -15,7 +15,7 @@ type Rule struct {
|
|||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
direction fw.RuleDirection
|
direction firewall.RuleDirection
|
||||||
sPort uint16
|
sPort uint16
|
||||||
dPort uint16
|
dPort uint16
|
||||||
drop bool
|
drop bool
|
||||||
|
@ -10,12 +10,16 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
|
|
||||||
|
var (
|
||||||
|
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||||
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
SetFilter(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
@ -27,12 +31,12 @@ type RuleSet map[string]Rule
|
|||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[string]RuleSet
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[string]RuleSet
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface IFaceMapper
|
wgIface IFaceMapper
|
||||||
resetHook func() error
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
@ -52,6 +56,20 @@ type decoder struct {
|
|||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface IFaceMapper) (*Manager, error) {
|
func Create(iface IFaceMapper) (*Manager, error) {
|
||||||
|
return create(iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||||
|
mgr, err := create(iface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.nativeFirewall = nativeFirewall
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func create(iface IFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
@ -77,27 +95,50 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errRouteNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.InsertRoutingRules(pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveRoutingRules removes a routing firewall rule
|
||||||
|
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errRouteNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.RemoveRoutingRules(pair)
|
||||||
|
}
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddFiltering(
|
func (m *Manager) AddFiltering(
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto fw.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *fw.Port,
|
sPort *firewall.Port,
|
||||||
dPort *fw.Port,
|
dPort *firewall.Port,
|
||||||
direction fw.RuleDirection,
|
direction firewall.RuleDirection,
|
||||||
action fw.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
r := Rule{
|
r := Rule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
matchByIP: true,
|
matchByIP: true,
|
||||||
direction: direction,
|
direction: direction,
|
||||||
drop: action == fw.ActionDrop,
|
drop: action == firewall.ActionDrop,
|
||||||
comment: comment,
|
comment: comment,
|
||||||
}
|
}
|
||||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||||
@ -118,21 +159,21 @@ func (m *Manager) AddFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case fw.ProtocolTCP:
|
case firewall.ProtocolTCP:
|
||||||
r.protoLayer = layers.LayerTypeTCP
|
r.protoLayer = layers.LayerTypeTCP
|
||||||
case fw.ProtocolUDP:
|
case firewall.ProtocolUDP:
|
||||||
r.protoLayer = layers.LayerTypeUDP
|
r.protoLayer = layers.LayerTypeUDP
|
||||||
case fw.ProtocolICMP:
|
case firewall.ProtocolICMP:
|
||||||
r.protoLayer = layers.LayerTypeICMPv4
|
r.protoLayer = layers.LayerTypeICMPv4
|
||||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||||
r.protoLayer = layers.LayerTypeICMPv6
|
r.protoLayer = layers.LayerTypeICMPv6
|
||||||
}
|
}
|
||||||
case fw.ProtocolALL:
|
case firewall.ProtocolALL:
|
||||||
r.protoLayer = layerTypeAll
|
r.protoLayer = layerTypeAll
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if direction == fw.RuleDirectionIN {
|
if direction == firewall.RuleDirectionIN {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||||
}
|
}
|
||||||
@ -144,12 +185,11 @@ func (m *Manager) AddFiltering(
|
|||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip.String()][r.id] = r
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
return []firewall.Rule{&r}, nil
|
||||||
return &r, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule firewall.Rule) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@ -158,7 +198,7 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.direction == fw.RuleDirectionIN {
|
if r.direction == firewall.RuleDirectionIN {
|
||||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
@ -322,7 +362,7 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
dPort: dPort,
|
dPort: dPort,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
direction: fw.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||||
udpHook: hook,
|
udpHook: hook,
|
||||||
}
|
}
|
||||||
@ -333,7 +373,7 @@ func (m *Manager) AddUDPPacketHook(
|
|||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
r.direction = fw.RuleDirectionIN
|
r.direction = firewall.RuleDirectionIN
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||||
}
|
}
|
||||||
@ -370,8 +410,3 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetResetHook which will be executed in the end of Reset method
|
|
||||||
func (m *Manager) SetResetHook(hook func() error) {
|
|
||||||
m.resetHook = hook
|
|
||||||
}
|
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -125,24 +125,32 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.DeleteRule(rule)
|
for _, r := range rule {
|
||||||
if err != nil {
|
err = m.DeleteRule(r)
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
if err != nil {
|
||||||
return
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
for _, r := range rule2 {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||||
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.DeleteRule(rule2)
|
for _, r := range rule2 {
|
||||||
if err != nil {
|
err = m.DeleteRule(r)
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
if err != nil {
|
||||||
return
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
for _, r := range rule2 {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
||||||
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,42 +11,27 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
|
||||||
type IFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
SetFilter(iface.PacketFilter) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||||
Stop()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
manager firewall.Manager
|
firewall firewall.Manager
|
||||||
ipsetCounter int
|
ipsetCounter int
|
||||||
rulesPairs map[string][]firewall.Rule
|
rulesPairs map[string][]firewall.Rule
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type ipsetInfo struct {
|
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
name string
|
|
||||||
ipCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
|
||||||
return &DefaultManager{
|
return &DefaultManager{
|
||||||
manager: fm,
|
firewall: fm,
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
rulesPairs: make(map[string][]firewall.Rule),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -69,13 +54,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
time.Since(start), total)
|
time.Since(start), total)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if d.manager == nil {
|
if d.firewall == nil {
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := d.manager.Flush(); err != nil {
|
if err := d.firewall.Flush(); err != nil {
|
||||||
log.Error("failed to flush firewall rules: ", err)
|
log.Error("failed to flush firewall rules: ", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -125,57 +110,35 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
applyFailed := false
|
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
newRulePairs := make(map[string][]firewall.Rule)
|
||||||
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
ipsetByRuleSelectors := make(map[string]string)
|
||||||
|
|
||||||
// calculate which IP's can be grouped in by which ipset
|
|
||||||
// to do that we use rule selector (which is just rule properties without IP's)
|
|
||||||
for _, r := range rules {
|
|
||||||
selector := d.getRuleGroupingSelector(r)
|
|
||||||
ipset, ok := ipsetByRuleSelectors[selector]
|
|
||||||
if !ok {
|
|
||||||
ipset = &ipsetInfo{}
|
|
||||||
}
|
|
||||||
|
|
||||||
ipset.ipCount++
|
|
||||||
ipsetByRuleSelectors[selector] = ipset
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
selector := d.getRuleGroupingSelector(r)
|
||||||
if ipset.name == "" {
|
ipsetName, ok := ipsetByRuleSelectors[selector]
|
||||||
|
if !ok {
|
||||||
d.ipsetCounter++
|
d.ipsetCounter++
|
||||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||||
|
ipsetByRuleSelectors[selector] = ipsetName
|
||||||
}
|
}
|
||||||
ipsetName := ipset.name
|
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||||
applyFailed = true
|
d.rollBack(newRulePairs)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
newRulePairs[pairID] = rulePair
|
if len(rules) > 0 {
|
||||||
}
|
d.rulesPairs[pairID] = rulePair
|
||||||
if applyFailed {
|
newRulePairs[pairID] = rulePair
|
||||||
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
|
||||||
for _, rules := range newRulePairs {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := d.manager.DeleteRule(rule); err != nil {
|
|
||||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for pairID, rules := range d.rulesPairs {
|
for pairID, rules := range d.rulesPairs {
|
||||||
if _, ok := newRulePairs[pairID]; !ok {
|
if _, ok := newRulePairs[pairID]; !ok {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := d.manager.DeleteRule(rule); err != nil {
|
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||||
log.Errorf("failed to delete firewall rule: %v", err)
|
log.Errorf("failed to delete firewall rule: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -186,16 +149,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
d.rulesPairs = newRulePairs
|
d.rulesPairs = newRulePairs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop ACL controller and clear firewall state
|
|
||||||
func (d *DefaultManager) Stop() {
|
|
||||||
d.mutex.Lock()
|
|
||||||
defer d.mutex.Unlock()
|
|
||||||
|
|
||||||
if err := d.manager.Reset(); err != nil {
|
|
||||||
log.WithError(err).Error("reset firewall state")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
r *mgmProto.FirewallRule,
|
r *mgmProto.FirewallRule,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
@ -205,14 +158,14 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol := convertToFirewallProtocol(r.Protocol)
|
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||||
if protocol == firewall.ProtocolUnknown {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
action := convertFirewallAction(r.Action)
|
action, err := convertFirewallAction(r.Action)
|
||||||
if action == firewall.ActionUnknown {
|
if err != nil {
|
||||||
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var port *firewall.Port
|
var port *firewall.Port
|
||||||
@ -232,7 +185,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
var err error
|
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.FirewallRule_IN:
|
case mgmProto.FirewallRule_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
@ -246,7 +198,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.rulesPairs[ruleID] = rules
|
|
||||||
return ruleID, rules, nil
|
return ruleID, rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,24 +210,24 @@ func (d *DefaultManager) addInRules(
|
|||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(
|
rule, err := d.firewall.AddFiltering(
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
rules = append(rules, rule)
|
rules = append(rules, rule...)
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(
|
rule, err = d.firewall.AddFiltering(
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(rules, rule), nil
|
return append(rules, rule...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
@ -288,24 +239,24 @@ func (d *DefaultManager) addOutRules(
|
|||||||
comment string,
|
comment string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(
|
rule, err := d.firewall.AddFiltering(
|
||||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
rules = append(rules, rule)
|
rules = append(rules, rule...)
|
||||||
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(
|
rule, err = d.firewall.AddFiltering(
|
||||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(rules, rule), nil
|
return append(rules, rule...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||||
@ -461,18 +412,29 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
|||||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
|
||||||
|
log.Debugf("rollback ACL to previous state")
|
||||||
|
for _, rules := range newRulePairs {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if err := d.firewall.DeleteRule(rule); err != nil {
|
||||||
|
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.FirewallRule_TCP:
|
case mgmProto.FirewallRule_TCP:
|
||||||
return firewall.ProtocolTCP
|
return firewall.ProtocolTCP, nil
|
||||||
case mgmProto.FirewallRule_UDP:
|
case mgmProto.FirewallRule_UDP:
|
||||||
return firewall.ProtocolUDP
|
return firewall.ProtocolUDP, nil
|
||||||
case mgmProto.FirewallRule_ICMP:
|
case mgmProto.FirewallRule_ICMP:
|
||||||
return firewall.ProtocolICMP
|
return firewall.ProtocolICMP, nil
|
||||||
case mgmProto.FirewallRule_ALL:
|
case mgmProto.FirewallRule_ALL:
|
||||||
return firewall.ProtocolALL
|
return firewall.ProtocolALL, nil
|
||||||
default:
|
default:
|
||||||
return firewall.ProtocolUnknown
|
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -480,13 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
|
|||||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action {
|
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
|
||||||
switch action {
|
switch action {
|
||||||
case mgmProto.FirewallRule_ACCEPT:
|
case mgmProto.FirewallRule_ACCEPT:
|
||||||
return firewall.ActionAccept
|
return firewall.ActionAccept, nil
|
||||||
case mgmProto.FirewallRule_DROP:
|
case mgmProto.FirewallRule_DROP:
|
||||||
return firewall.ActionDrop
|
return firewall.ActionDrop, nil
|
||||||
default:
|
default:
|
||||||
return firewall.ActionUnknown
|
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create creates a firewall manager instance
|
|
||||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|
||||||
if iface.IsUserspaceBind() {
|
|
||||||
// use userspace packet filtering firewall
|
|
||||||
fm, err := uspfilter.Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := fm.AllowNetbird(); err != nil {
|
|
||||||
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
return newDefaultManager(fm), nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
|
||||||
}
|
|
@ -1,77 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create creates a firewall manager instance for the Linux
|
|
||||||
func Create(iface IFaceMapper) (*DefaultManager, error) {
|
|
||||||
// on the linux system we try to user nftables or iptables
|
|
||||||
// in any case, because we need to allow netbird interface traffic
|
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
|
||||||
// for the userspace packet filtering firewall
|
|
||||||
var fm firewall.Manager
|
|
||||||
var err error
|
|
||||||
|
|
||||||
checkResult := checkfw.Check()
|
|
||||||
switch checkResult {
|
|
||||||
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
|
||||||
log.Debug("creating an iptables firewall manager for access control")
|
|
||||||
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
|
||||||
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
|
|
||||||
log.Infof("failed to create iptables manager for access control: %s", err)
|
|
||||||
}
|
|
||||||
case checkfw.NFTABLES:
|
|
||||||
log.Debug("creating an nftables firewall manager for access control")
|
|
||||||
if fm, err = nftables.Create(iface); err != nil {
|
|
||||||
log.Debugf("failed to create nftables manager for access control: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var resetHookForUserspace func() error
|
|
||||||
if fm != nil && err == nil {
|
|
||||||
// err shadowing is used here, to ignore this error
|
|
||||||
if err := fm.AllowNetbird(); err != nil {
|
|
||||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
resetHookForUserspace = fm.Reset
|
|
||||||
}
|
|
||||||
|
|
||||||
if iface.IsUserspaceBind() {
|
|
||||||
// use userspace packet filtering firewall
|
|
||||||
usfm, err := uspfilter.Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set kernel space firewall Reset as hook for userspace firewall
|
|
||||||
// manager Reset method, to clean up
|
|
||||||
if resetHookForUserspace != nil {
|
|
||||||
usfm.SetResetHook(resetHookForUserspace)
|
|
||||||
}
|
|
||||||
|
|
||||||
// to be consistent for any future extensions.
|
|
||||||
// ignore this error
|
|
||||||
if err := usfm.AllowNetbird(); err != nil {
|
|
||||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
fm = usfm
|
|
||||||
}
|
|
||||||
|
|
||||||
if fm == nil || err != nil {
|
|
||||||
log.Errorf("failed to create firewall manager: %s", err)
|
|
||||||
// no firewall manager found or initialized correctly
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return newDefaultManager(fm), nil
|
|
||||||
}
|
|
@ -1,11 +1,14 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
@ -49,12 +52,15 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(ifaceMock)
|
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create ACL manager: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer acl.Stop()
|
defer func(fw manager.Manager) {
|
||||||
|
_ = fw.Reset()
|
||||||
|
}(fw)
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
t.Run("apply firewall rules", func(t *testing.T) {
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
@ -339,12 +345,15 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(ifaceMock)
|
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create ACL manager: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer acl.Stop()
|
defer func(fw manager.Manager) {
|
||||||
|
_ = fw.Reset()
|
||||||
|
}(fw)
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap)
|
||||||
|
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package checkfw
|
|
@ -1,56 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package checkfw
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
|
||||||
UNKNOWN FWType = iota
|
|
||||||
// IPTABLES is the value for the iptables firewall type
|
|
||||||
IPTABLES
|
|
||||||
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
|
|
||||||
IPTABLESWITHV6
|
|
||||||
// NFTABLES is the value for the nftables firewall type
|
|
||||||
NFTABLES
|
|
||||||
)
|
|
||||||
|
|
||||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
|
||||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|
||||||
|
|
||||||
// FWType is the type for the firewall type
|
|
||||||
type FWType int
|
|
||||||
|
|
||||||
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
|
||||||
func Check() FWType {
|
|
||||||
nf := nftables.Conn{}
|
|
||||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
|
||||||
return NFTABLES
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
if err == nil {
|
|
||||||
if isIptablesClientAvailable(ip) {
|
|
||||||
ipSupport := IPTABLES
|
|
||||||
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
if ip6Err == nil {
|
|
||||||
if isIptablesClientAvailable(ipv6) {
|
|
||||||
ipSupport = IPTABLESWITHV6
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ipSupport
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return UNKNOWN
|
|
||||||
}
|
|
||||||
|
|
||||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|
||||||
_, err := client.ListChains("filter")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
@ -25,13 +25,30 @@ const (
|
|||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
|
func (t osManagerType) String() string {
|
||||||
|
switch t {
|
||||||
|
case netbirdManager:
|
||||||
|
return "netbird"
|
||||||
|
case fileManager:
|
||||||
|
return "file"
|
||||||
|
case networkManager:
|
||||||
|
return "networkManager"
|
||||||
|
case systemdManager:
|
||||||
|
return "systemd"
|
||||||
|
case resolvConfManager:
|
||||||
|
return "resolvconf"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("discovered mode is: %d", osManager)
|
log.Debugf("discovered mode is: %s", osManager)
|
||||||
switch osManager {
|
switch osManager {
|
||||||
case networkManager:
|
case networkManager:
|
||||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||||
@ -65,7 +82,6 @@ func getOSDNSManagerType() (osManagerType, error) {
|
|||||||
return netbirdManager, nil
|
return netbirdManager, nil
|
||||||
}
|
}
|
||||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
|
|
||||||
return networkManager, nil
|
return networkManager, nil
|
||||||
}
|
}
|
||||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||||
|
@ -17,6 +17,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@ -115,6 +117,7 @@ type Engine struct {
|
|||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
|
firewall manager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
|
|
||||||
@ -231,6 +234,19 @@ func (e *Engine) Start() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
|
||||||
|
err = e.routeManager.EnableServerRouter(e.firewall)
|
||||||
|
if err != nil {
|
||||||
|
e.close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
|
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
|
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
|
||||||
@ -258,10 +274,8 @@ func (e *Engine) Start() error {
|
|||||||
e.udpMux = mux
|
e.udpMux = mux
|
||||||
}
|
}
|
||||||
|
|
||||||
if acl, err := acl.Create(e.wgInterface); err != nil {
|
if e.firewall != nil {
|
||||||
log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error())
|
e.acl = acl.NewDefaultManager(e.firewall)
|
||||||
} else {
|
|
||||||
e.acl = acl
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.Initialize()
|
err = e.dnsServer.Initialize()
|
||||||
@ -1044,8 +1058,11 @@ func (e *Engine) close() {
|
|||||||
e.dnsServer.Stop()
|
e.dnsServer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.acl != nil {
|
if e.firewall != nil {
|
||||||
e.acl.Stop()
|
err := e.firewall.Reset()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to reset firewall: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,75 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
var insertRuleTestCases = []struct {
|
|
||||||
name string
|
|
||||||
inputPair routerPair
|
|
||||||
ipVersion string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Insert Forwarding IPV4 Rule",
|
|
||||||
inputPair: routerPair{
|
|
||||||
ID: "zxa",
|
|
||||||
source: "100.100.100.1/32",
|
|
||||||
destination: "100.100.200.0/24",
|
|
||||||
masquerade: false,
|
|
||||||
},
|
|
||||||
ipVersion: ipv4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Insert Forwarding And Nat IPV4 Rules",
|
|
||||||
inputPair: routerPair{
|
|
||||||
ID: "zxa",
|
|
||||||
source: "100.100.100.1/32",
|
|
||||||
destination: "100.100.200.0/24",
|
|
||||||
masquerade: true,
|
|
||||||
},
|
|
||||||
ipVersion: ipv4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Insert Forwarding IPV6 Rule",
|
|
||||||
inputPair: routerPair{
|
|
||||||
ID: "zxa",
|
|
||||||
source: "fc00::1/128",
|
|
||||||
destination: "fc12::/64",
|
|
||||||
masquerade: false,
|
|
||||||
},
|
|
||||||
ipVersion: ipv6,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Insert Forwarding And Nat IPV6 Rules",
|
|
||||||
inputPair: routerPair{
|
|
||||||
ID: "zxa",
|
|
||||||
source: "fc00::1/128",
|
|
||||||
destination: "fc12::/64",
|
|
||||||
masquerade: true,
|
|
||||||
},
|
|
||||||
ipVersion: ipv6,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var removeRuleTestCases = []struct {
|
|
||||||
name string
|
|
||||||
inputPair routerPair
|
|
||||||
ipVersion string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Remove Forwarding And Nat IPV4 Rules",
|
|
||||||
inputPair: routerPair{
|
|
||||||
ID: "zxa",
|
|
||||||
source: "100.100.100.1/32",
|
|
||||||
destination: "100.100.200.0/24",
|
|
||||||
masquerade: true,
|
|
||||||
},
|
|
||||||
ipVersion: ipv4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Remove Forwarding And Nat IPV6 Rules",
|
|
||||||
inputPair: routerPair{
|
|
||||||
ID: "zxa",
|
|
||||||
source: "fc00::1/128",
|
|
||||||
destination: "fc12::/64",
|
|
||||||
masquerade: true,
|
|
||||||
},
|
|
||||||
ipVersion: ipv6,
|
|
||||||
},
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
type firewallManager interface {
|
|
||||||
// RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules
|
|
||||||
RestoreOrCreateContainers() error
|
|
||||||
// InsertRoutingRules inserts a routing firewall rule
|
|
||||||
InsertRoutingRules(pair routerPair) error
|
|
||||||
// RemoveRoutingRules removes a routing firewall rule
|
|
||||||
RemoveRoutingRules(pair routerPair) error
|
|
||||||
// CleanRoutingRules cleans a firewall set of containers
|
|
||||||
CleanRoutingRules()
|
|
||||||
}
|
|
@ -1,55 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
|
|
||||||
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
|
|
||||||
ipv6Nat = "netbird-rt-ipv6-nat"
|
|
||||||
ipv4Nat = "netbird-rt-ipv4-nat"
|
|
||||||
natFormat = "netbird-nat-%s"
|
|
||||||
forwardingFormat = "netbird-fwd-%s"
|
|
||||||
inNatFormat = "netbird-nat-in-%s"
|
|
||||||
inForwardingFormat = "netbird-fwd-in-%s"
|
|
||||||
ipv6 = "ipv6"
|
|
||||||
ipv4 = "ipv4"
|
|
||||||
)
|
|
||||||
|
|
||||||
func genKey(format string, input string) string {
|
|
||||||
return fmt.Sprintf(format, input)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
|
||||||
func newFirewall(parentCTX context.Context) (firewallManager, error) {
|
|
||||||
checkResult := checkfw.Check()
|
|
||||||
switch checkResult {
|
|
||||||
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
|
||||||
log.Debug("creating an iptables firewall manager for route rules")
|
|
||||||
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
|
||||||
return newIptablesManager(parentCTX, ipv6Supported)
|
|
||||||
case checkfw.NFTABLES:
|
|
||||||
log.Info("creating an nftables firewall manager for route rules")
|
|
||||||
return newNFTablesManager(parentCTX), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
|
|
||||||
}
|
|
||||||
|
|
||||||
func getInPair(pair routerPair) routerPair {
|
|
||||||
return routerPair{
|
|
||||||
ID: pair.ID,
|
|
||||||
// invert source/destination
|
|
||||||
source: pair.destination,
|
|
||||||
destination: pair.source,
|
|
||||||
masquerade: pair.masquerade,
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,15 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
// +build !linux
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newFirewall returns a nil manager
|
|
||||||
func newFirewall(context.Context) (firewallManager, error) {
|
|
||||||
return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
|
|
||||||
}
|
|
@ -1,487 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
|
||||||
_, err4 := exec.LookPath("iptables")
|
|
||||||
_, err6 := exec.LookPath("ip6tables")
|
|
||||||
return err4 == nil && err6 == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
|
||||||
const (
|
|
||||||
iptablesFilterTable = "filter"
|
|
||||||
iptablesNatTable = "nat"
|
|
||||||
iptablesForwardChain = "FORWARD"
|
|
||||||
iptablesPostRoutingChain = "POSTROUTING"
|
|
||||||
iptablesRoutingNatChain = "NETBIRD-RT-NAT"
|
|
||||||
iptablesRoutingForwardingChain = "NETBIRD-RT-FWD"
|
|
||||||
routingFinalForwardJump = "ACCEPT"
|
|
||||||
routingFinalNatJump = "MASQUERADE"
|
|
||||||
)
|
|
||||||
|
|
||||||
// some presets for building nftable rules
|
|
||||||
var (
|
|
||||||
iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"}
|
|
||||||
iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"}
|
|
||||||
iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"}
|
|
||||||
iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"}
|
|
||||||
)
|
|
||||||
|
|
||||||
type iptablesManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
ipv4Client *iptables.IPTables
|
|
||||||
ipv6Client *iptables.IPTables
|
|
||||||
rules map[string]map[string][]string
|
|
||||||
mux sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
|
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
manager := &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipv6Supported {
|
|
||||||
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return manager, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
|
||||||
func (i *iptablesManager) CleanRoutingRules() {
|
|
||||||
i.mux.Lock()
|
|
||||||
defer i.mux.Unlock()
|
|
||||||
|
|
||||||
err := i.cleanJumpRules()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("flushing tables")
|
|
||||||
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
|
||||||
if i.ipv4Client != nil {
|
|
||||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if i.ipv6Client != nil {
|
|
||||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("done cleaning up iptables rules")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreOrCreateContainers restores existing iptables containers (chains and rules)
|
|
||||||
// if they don't exist, we create them
|
|
||||||
func (i *iptablesManager) RestoreOrCreateContainers() error {
|
|
||||||
i.mux.Lock()
|
|
||||||
defer i.mux.Unlock()
|
|
||||||
|
|
||||||
if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
|
||||||
|
|
||||||
if i.ipv4Client != nil {
|
|
||||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv4Client)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if i.ipv6Client != nil {
|
|
||||||
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.restoreRules(i.ipv6Client)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := i.addJumpRules()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addJumpRules create jump rules to send packets to NetBird chains
|
|
||||||
func (i *iptablesManager) addJumpRules() error {
|
|
||||||
err := i.cleanJumpRules()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if i.ipv4Client != nil {
|
|
||||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) //nolint:gocritic
|
|
||||||
|
|
||||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv4][ipv4Forwarding] = rule
|
|
||||||
|
|
||||||
rule = append(iptablesDefaultNatRule, ipv4Nat) //nolint:gocritic
|
|
||||||
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv4][ipv4Nat] = rule
|
|
||||||
}
|
|
||||||
|
|
||||||
if i.ipv6Client != nil {
|
|
||||||
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding) //nolint:gocritic
|
|
||||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv6][ipv6Forwarding] = rule
|
|
||||||
|
|
||||||
rule = append(iptablesDefaultNatRule, ipv6Nat) //nolint:gocritic
|
|
||||||
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
i.rules[ipv6][ipv6Nat] = rule
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
|
|
||||||
func (i *iptablesManager) cleanJumpRules() error {
|
|
||||||
var err error
|
|
||||||
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
|
||||||
rule, found := i.rules[ipv4][ipv4Forwarding]
|
|
||||||
if i.ipv4Client != nil {
|
|
||||||
if found {
|
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
|
||||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rule, found = i.rules[ipv4][ipv4Nat]
|
|
||||||
if found {
|
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
|
||||||
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if i.ipv6Client == nil {
|
|
||||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
|
||||||
if found {
|
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
|
||||||
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rule, found = i.rules[ipv6][ipv6Nat]
|
|
||||||
if found {
|
|
||||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
|
||||||
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func iptablesProtoToString(proto iptables.Protocol) string {
|
|
||||||
if proto == iptables.ProtocolIPv6 {
|
|
||||||
return ipv6
|
|
||||||
}
|
|
||||||
return ipv4
|
|
||||||
}
|
|
||||||
|
|
||||||
// restoreRules restores existing NetBird rules
|
|
||||||
func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error {
|
|
||||||
ipVersion := iptablesProtoToString(iptablesClient.Proto())
|
|
||||||
|
|
||||||
if i.rules[ipVersion] == nil {
|
|
||||||
i.rules[ipVersion] = make(map[string][]string)
|
|
||||||
}
|
|
||||||
table := iptablesFilterTable
|
|
||||||
for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} {
|
|
||||||
rules, err := iptablesClient.List(table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, ruleString := range rules {
|
|
||||||
rule := strings.Fields(ruleString)
|
|
||||||
id := getRuleRouteID(rule)
|
|
||||||
if id != "" {
|
|
||||||
i.rules[ipVersion][id] = rule[2:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
table = iptablesNatTable
|
|
||||||
for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} {
|
|
||||||
rules, err := iptablesClient.List(table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, ruleString := range rules {
|
|
||||||
rule := strings.Fields(ruleString)
|
|
||||||
id := getRuleRouteID(rule)
|
|
||||||
if id != "" {
|
|
||||||
i.rules[ipVersion][id] = rule[2:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createChain create NetBird chains
|
|
||||||
func createChain(iptables *iptables.IPTables, table, newChain string) error {
|
|
||||||
chains, err := iptables.ListChains(table)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
shouldCreateChain := true
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain == newChain {
|
|
||||||
shouldCreateChain = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldCreateChain {
|
|
||||||
err = iptables.NewChain(table, newChain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if table == iptablesNatTable {
|
|
||||||
err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...)
|
|
||||||
} else {
|
|
||||||
err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// genRuleSpec generates rule specification with comment identifier
|
|
||||||
func genRuleSpec(jump, id, source, destination string) []string {
|
|
||||||
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleRouteID returns the rule ID if matches our prefix
|
|
||||||
func getRuleRouteID(rule []string) string {
|
|
||||||
for i, flag := range rule {
|
|
||||||
if flag == "--comment" {
|
|
||||||
id := rule[i+1]
|
|
||||||
if strings.HasPrefix(id, "netbird-") {
|
|
||||||
return id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
|
|
||||||
func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
|
|
||||||
i.mux.Lock()
|
|
||||||
defer i.mux.Unlock()
|
|
||||||
|
|
||||||
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pair.masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// insertRoutingRule inserts an iptable rule
|
|
||||||
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
|
||||||
ipVersion := ipv4
|
|
||||||
iptablesClient := i.ipv4Client
|
|
||||||
if prefix.Addr().Unmap().Is6() {
|
|
||||||
iptablesClient = i.ipv6Client
|
|
||||||
ipVersion = ipv6
|
|
||||||
}
|
|
||||||
|
|
||||||
if iptablesClient == nil {
|
|
||||||
return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleKey := genKey(keyFormat, pair.ID)
|
|
||||||
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
|
|
||||||
existingRule, found := i.rules[ipVersion][ruleKey]
|
|
||||||
if found {
|
|
||||||
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
|
||||||
}
|
|
||||||
delete(i.rules[ipVersion], ruleKey)
|
|
||||||
}
|
|
||||||
err = iptablesClient.Insert(table, chain, 1, rule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
i.rules[ipVersion][ruleKey] = rule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
|
|
||||||
func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
|
|
||||||
i.mux.Lock()
|
|
||||||
defer i.mux.Unlock()
|
|
||||||
|
|
||||||
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pair.masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRoutingRule removes an iptables rule
|
|
||||||
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
|
||||||
ipVersion := ipv4
|
|
||||||
iptablesClient := i.ipv4Client
|
|
||||||
if prefix.Addr().Unmap().Is6() {
|
|
||||||
iptablesClient = i.ipv6Client
|
|
||||||
ipVersion = ipv6
|
|
||||||
}
|
|
||||||
|
|
||||||
if iptablesClient == nil {
|
|
||||||
return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleKey := genKey(keyFormat, pair.ID)
|
|
||||||
existingRule, found := i.rules[ipVersion][ruleKey]
|
|
||||||
if found {
|
|
||||||
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(i.rules[ipVersion], ruleKey)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getIptablesRuleType(table string) string {
|
|
||||||
ruleType := "forwarding"
|
|
||||||
if table == iptablesNatTable {
|
|
||||||
ruleType = "nat"
|
|
||||||
}
|
|
||||||
return ruleType
|
|
||||||
}
|
|
@ -1,294 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := newIptablesManager(context.TODO(), true)
|
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
|
|
||||||
|
|
||||||
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
|
||||||
|
|
||||||
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
|
||||||
|
|
||||||
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
|
||||||
|
|
||||||
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
|
||||||
|
|
||||||
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
|
||||||
|
|
||||||
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
|
||||||
|
|
||||||
pair := routerPair{
|
|
||||||
ID: "abc",
|
|
||||||
source: "100.100.100.1/32",
|
|
||||||
destination: "100.100.100.0/24",
|
|
||||||
masquerade: true,
|
|
||||||
}
|
|
||||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
|
||||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
|
||||||
|
|
||||||
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
|
||||||
|
|
||||||
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
pair = routerPair{
|
|
||||||
ID: "abc",
|
|
||||||
source: "fc00::1/128",
|
|
||||||
destination: "fc11::/64",
|
|
||||||
masquerade: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
|
||||||
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
|
||||||
|
|
||||||
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
|
||||||
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
|
||||||
|
|
||||||
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
delete(manager.rules, ipv4)
|
|
||||||
delete(manager.rules, ipv6)
|
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4")
|
|
||||||
|
|
||||||
foundRule, found := manager.rules[ipv4][forward4RuleKey]
|
|
||||||
require.True(t, found, "forwarding rule should exist in the map")
|
|
||||||
require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[ipv4][nat4RuleKey]
|
|
||||||
require.True(t, found, "nat rule should exist in the map")
|
|
||||||
require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match")
|
|
||||||
|
|
||||||
require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[ipv6][forward6RuleKey]
|
|
||||||
require.True(t, found, "forwarding rule should exist in the map")
|
|
||||||
require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[ipv6][nat6RuleKey]
|
|
||||||
require.True(t, found, "nat rule should exist in the map")
|
|
||||||
require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range insertRuleTestCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
|
||||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
iptablesClient := ipv4Client
|
|
||||||
if testCase.ipVersion == ipv6 {
|
|
||||||
iptablesClient = ipv6Client
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
|
||||||
|
|
||||||
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
require.True(t, exists, "forwarding rule should exist")
|
|
||||||
|
|
||||||
foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
|
||||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
|
||||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
|
||||||
|
|
||||||
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
require.True(t, exists, "income forwarding rule should exist")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
|
|
||||||
require.True(t, found, "income forwarding rule should exist in the manager map")
|
|
||||||
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if testCase.inputPair.masquerade {
|
|
||||||
require.True(t, exists, "nat rule should be created")
|
|
||||||
foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
|
||||||
require.True(t, foundNat, "nat rule should exist in the map")
|
|
||||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
|
||||||
} else {
|
|
||||||
require.False(t, exists, "nat rule should not be created")
|
|
||||||
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
|
||||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
|
||||||
}
|
|
||||||
|
|
||||||
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
if testCase.inputPair.masquerade {
|
|
||||||
require.True(t, exists, "income nat rule should be created")
|
|
||||||
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
|
|
||||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
|
||||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
|
||||||
} else {
|
|
||||||
require.False(t, exists, "nat rule should not be created")
|
|
||||||
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
|
|
||||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range removeRuleTestCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
|
||||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
iptablesClient := ipv4Client
|
|
||||||
if testCase.ipVersion == ipv6 {
|
|
||||||
iptablesClient = ipv6Client
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &iptablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
ipv4Client: ipv4Client,
|
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
|
||||||
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
delete(manager.rules, ipv4)
|
|
||||||
delete(manager.rules, ipv6)
|
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.RemoveRoutingRules(testCase.inputPair)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
require.False(t, exists, "forwarding rule should not exist")
|
|
||||||
|
|
||||||
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
|
||||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
|
||||||
require.False(t, exists, "income forwarding rule should not exist")
|
|
||||||
|
|
||||||
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
|
|
||||||
require.False(t, found, "income forwarding rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
require.False(t, exists, "nat rule should not exist")
|
|
||||||
|
|
||||||
_, found = manager.rules[testCase.ipVersion][natRuleKey]
|
|
||||||
require.False(t, found, "nat rule should exist in the manager map")
|
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
|
||||||
require.False(t, exists, "income nat rule should not exist")
|
|
||||||
|
|
||||||
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
|
|
||||||
require.False(t, found, "income nat rule should exist in the manager map")
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
@ -19,6 +20,7 @@ type Manager interface {
|
|||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
|
EnableServerRouter(firewall firewall.Manager) error
|
||||||
Stop()
|
Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -35,19 +37,12 @@ type DefaultManager struct {
|
|||||||
notifier *notifier
|
notifier *notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager returns a new route manager
|
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
||||||
srvRouter, err := newServerRouter(ctx, wgInterface)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("server router is not supported: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
dm := &DefaultManager{
|
dm := &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
clientNetworks: make(map[string]*clientNetwork),
|
||||||
serverRouter: srvRouter,
|
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
pubKey: pubKey,
|
pubKey: pubKey,
|
||||||
@ -61,6 +56,15 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||||
|
var err error
|
||||||
|
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the manager watchers and clean firewall rules
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
func (m *DefaultManager) Stop() {
|
func (m *DefaultManager) Stop() {
|
||||||
m.stop()
|
m.stop()
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -37,6 +38,10 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
// Stop mock implementation of Stop from Manager interface
|
// Stop mock implementation of Stop from Manager interface
|
||||||
func (m *MockManager) Stop() {
|
func (m *MockManager) Stop() {
|
||||||
if m.StopFunc != nil {
|
if m.StopFunc != nil {
|
||||||
|
@ -1,571 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
nftablesTable = "netbird-rt"
|
|
||||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
|
||||||
nftablesRoutingNatChain = "netbird-rt-nat"
|
|
||||||
|
|
||||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
|
||||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
|
||||||
)
|
|
||||||
|
|
||||||
// constants needed to create nftable rules
|
|
||||||
const (
|
|
||||||
ipv4Len = 4
|
|
||||||
ipv4SrcOffset = 12
|
|
||||||
ipv4DestOffset = 16
|
|
||||||
ipv6Len = 16
|
|
||||||
ipv6SrcOffset = 8
|
|
||||||
ipv6DestOffset = 24
|
|
||||||
exprDirectionSource = "source"
|
|
||||||
exprDirectionDestination = "destination"
|
|
||||||
)
|
|
||||||
|
|
||||||
// some presets for building nftable rules
|
|
||||||
var (
|
|
||||||
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
|
||||||
|
|
||||||
zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...)
|
|
||||||
|
|
||||||
exprAllowRelatedEstablished = []expr.Any{
|
|
||||||
&expr.Ct{
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: false,
|
|
||||||
Key: 0,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
DestRegister: 1,
|
|
||||||
SourceRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: []uint8{0x6, 0x0, 0x0, 0x0},
|
|
||||||
Xor: zeroXor,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
exprCounterAccept = []expr.Any{
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type nftablesManager struct {
|
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
conn *nftables.Conn
|
|
||||||
tableIPv4 *nftables.Table
|
|
||||||
tableIPv6 *nftables.Table
|
|
||||||
chains map[string]map[string]*nftables.Chain
|
|
||||||
rules map[string]*nftables.Rule
|
|
||||||
filterTable *nftables.Table
|
|
||||||
defaultForwardRules []*nftables.Rule
|
|
||||||
mux sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNFTablesManager(parentCtx context.Context) *nftablesManager {
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
|
|
||||||
return &nftablesManager{
|
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
chains: make(map[string]map[string]*nftables.Chain),
|
|
||||||
rules: make(map[string]*nftables.Rule),
|
|
||||||
defaultForwardRules: make([]*nftables.Rule, 2),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing nftables rules from the system
|
|
||||||
func (n *nftablesManager) CleanRoutingRules() {
|
|
||||||
n.mux.Lock()
|
|
||||||
defer n.mux.Unlock()
|
|
||||||
log.Debug("flushing tables")
|
|
||||||
if n.tableIPv4 != nil && n.tableIPv6 != nil {
|
|
||||||
n.conn.FlushTable(n.tableIPv6)
|
|
||||||
n.conn.FlushTable(n.tableIPv4)
|
|
||||||
}
|
|
||||||
|
|
||||||
if n.defaultForwardRules[0] != nil {
|
|
||||||
err := n.eraseDefaultForwardRule()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete forward rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreOrCreateContainers restores existing nftables containers (tables and chains)
|
|
||||||
// if they don't exist, we create them
|
|
||||||
func (n *nftablesManager) RestoreOrCreateContainers() error {
|
|
||||||
n.mux.Lock()
|
|
||||||
defer n.mux.Unlock()
|
|
||||||
|
|
||||||
if n.tableIPv6 != nil && n.tableIPv4 != nil {
|
|
||||||
log.Debugf("nftables: containers already restored, skipping")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := n.conn.ListTables()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to list tables: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range tables {
|
|
||||||
if table.Name == "filter" && table.Family == nftables.TableFamilyIPv4 {
|
|
||||||
log.Debugf("nftables: found filter table for ipv4")
|
|
||||||
n.filterTable = table
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if table.Name == nftablesTable {
|
|
||||||
if table.Family == nftables.TableFamilyIPv4 {
|
|
||||||
n.tableIPv4 = table
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
n.tableIPv6 = table
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if n.tableIPv4 == nil {
|
|
||||||
n.tableIPv4 = n.conn.AddTable(&nftables.Table{
|
|
||||||
Name: nftablesTable,
|
|
||||||
Family: nftables.TableFamilyIPv4,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if n.tableIPv6 == nil {
|
|
||||||
n.tableIPv6 = n.conn.AddTable(&nftables.Table{
|
|
||||||
Name: nftablesTable,
|
|
||||||
Family: nftables.TableFamilyIPv6,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := n.conn.ListChains()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to list chains: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
n.chains[ipv4] = make(map[string]*nftables.Chain)
|
|
||||||
n.chains[ipv6] = make(map[string]*nftables.Chain)
|
|
||||||
|
|
||||||
for _, chain := range chains {
|
|
||||||
switch {
|
|
||||||
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4:
|
|
||||||
n.chains[ipv4][chain.Name] = chain
|
|
||||||
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6:
|
|
||||||
n.chains[ipv6][chain.Name] = chain
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found {
|
|
||||||
n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: nftablesRoutingForwardingChain,
|
|
||||||
Table: n.tableIPv4,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityNATDest + 1,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found {
|
|
||||||
n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: nftablesRoutingNatChain,
|
|
||||||
Table: n.tableIPv4,
|
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
|
||||||
Priority: nftables.ChainPriorityNATSource - 1,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found {
|
|
||||||
n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: nftablesRoutingForwardingChain,
|
|
||||||
Table: n.tableIPv6,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityNATDest + 1,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found {
|
|
||||||
n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: nftablesRoutingNatChain,
|
|
||||||
Table: n.tableIPv6,
|
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
|
||||||
Priority: nftables.ChainPriorityNATSource - 1,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
n.checkOrCreateDefaultForwardingRules()
|
|
||||||
err = n.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
|
||||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
|
||||||
func (n *nftablesManager) refreshRulesMap() error {
|
|
||||||
for _, registeredChains := range n.chains {
|
|
||||||
for _, chain := range registeredChains {
|
|
||||||
rules, err := n.conn.GetRules(chain.Table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
|
||||||
}
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
n.rules[string(rule.UserData)] = rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
|
||||||
if n.defaultForwardRules[0] == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := n.refreshDefaultForwardRule()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range n.defaultForwardRules {
|
|
||||||
err = n.conn.DelRule(r)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete forward rule (%d): %s", i, err)
|
|
||||||
}
|
|
||||||
n.defaultForwardRules[i] = nil
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nftablesManager) refreshDefaultForwardRule() error {
|
|
||||||
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to list rules in forward chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for i, r := range n.defaultForwardRules {
|
|
||||||
for _, rule := range rules {
|
|
||||||
if string(rule.UserData) == string(r.UserData) {
|
|
||||||
n.defaultForwardRules[i] = rule
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("unable to find forward accept rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
|
|
||||||
src := generateCIDRMatcherExpressions("source", sourceNetwork)
|
|
||||||
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
|
|
||||||
|
|
||||||
var exprs []expr.Any
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
r := &nftables.Rule{
|
|
||||||
Table: n.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: n.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleSrc),
|
|
||||||
}
|
|
||||||
|
|
||||||
n.defaultForwardRules[0] = n.conn.AddRule(r)
|
|
||||||
|
|
||||||
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
|
|
||||||
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
|
|
||||||
|
|
||||||
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
r = &nftables.Rule{
|
|
||||||
Table: n.filterTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: "FORWARD",
|
|
||||||
Table: n.filterTable,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
},
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleDst),
|
|
||||||
}
|
|
||||||
|
|
||||||
n.defaultForwardRules[1] = n.conn.AddRule(r)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
|
||||||
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
|
||||||
_, foundIPv4 := n.rules[ipv4Forwarding]
|
|
||||||
if !foundIPv4 {
|
|
||||||
n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv4,
|
|
||||||
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: exprAllowRelatedEstablished,
|
|
||||||
UserData: []byte(ipv4Forwarding),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
_, foundIPv6 := n.rules[ipv6Forwarding]
|
|
||||||
if !foundIPv6 {
|
|
||||||
n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv6,
|
|
||||||
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: exprAllowRelatedEstablished,
|
|
||||||
UserData: []byte(ipv6Forwarding),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
|
||||||
func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|
||||||
n.mux.Lock()
|
|
||||||
defer n.mux.Unlock()
|
|
||||||
|
|
||||||
err := n.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair.masquerade {
|
|
||||||
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
|
|
||||||
err = n.acceptForwardRule(pair.source)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to create default forward rule: %s", err)
|
|
||||||
}
|
|
||||||
log.Debugf("default accept forward rule added")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
|
||||||
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
|
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
|
||||||
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
|
||||||
|
|
||||||
var expression []expr.Any
|
|
||||||
if isNat {
|
|
||||||
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
} else {
|
|
||||||
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleKey := genKey(format, pair.ID)
|
|
||||||
|
|
||||||
_, exists := n.rules[ruleKey]
|
|
||||||
if exists {
|
|
||||||
err := n.removeRoutingRule(format, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefix.Addr().Unmap().Is4() {
|
|
||||||
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv4,
|
|
||||||
Chain: n.chains[ipv4][chain],
|
|
||||||
Exprs: expression,
|
|
||||||
UserData: []byte(ruleKey),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv6,
|
|
||||||
Chain: n.chains[ipv6][chain],
|
|
||||||
Exprs: expression,
|
|
||||||
UserData: []byte(ruleKey),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
|
||||||
func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|
||||||
n.mux.Lock()
|
|
||||||
defer n.mux.Unlock()
|
|
||||||
|
|
||||||
err := n.refreshRulesMap()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.removeRoutingRule(forwardingFormat, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.removeRoutingRule(natFormat, pair)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
|
|
||||||
err := n.eraseDefaultForwardRule()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete default fwd rule: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = n.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
log.Debugf("nftables: removed rules for %s", pair.destination)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
|
||||||
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
|
|
||||||
ruleKey := genKey(format, pair.ID)
|
|
||||||
|
|
||||||
rule, found := n.rules[ruleKey]
|
|
||||||
if found {
|
|
||||||
ruleType := "forwarding"
|
|
||||||
if rule.Chain.Type == nftables.ChainTypeNAT {
|
|
||||||
ruleType = "nat"
|
|
||||||
}
|
|
||||||
|
|
||||||
err := n.conn.DelRule(rule)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
|
|
||||||
|
|
||||||
delete(n.rules, ruleKey)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPayloadDirectives get expression directives based on ip version and direction
|
|
||||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
|
||||||
switch {
|
|
||||||
case direction == exprDirectionSource && isIPv4:
|
|
||||||
return ipv4SrcOffset, ipv4Len, zeroXor
|
|
||||||
case direction == exprDirectionDestination && isIPv4:
|
|
||||||
return ipv4DestOffset, ipv4Len, zeroXor
|
|
||||||
case direction == exprDirectionSource && isIPv6:
|
|
||||||
return ipv6SrcOffset, ipv6Len, zeroXor6
|
|
||||||
case direction == exprDirectionDestination && isIPv6:
|
|
||||||
return ipv6DestOffset, ipv6Len, zeroXor6
|
|
||||||
default:
|
|
||||||
panic("no matched payload directive")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
|
||||||
func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any {
|
|
||||||
ip, network, _ := net.ParseCIDR(cidr)
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
|
|
||||||
offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6())
|
|
||||||
|
|
||||||
return []expr.Any{
|
|
||||||
// fetch src add
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: offSet,
|
|
||||||
Len: packetLen,
|
|
||||||
},
|
|
||||||
// net mask
|
|
||||||
&expr.Bitwise{
|
|
||||||
DestRegister: 1,
|
|
||||||
SourceRegister: 1,
|
|
||||||
Len: packetLen,
|
|
||||||
Mask: network.Mask,
|
|
||||||
Xor: zeroXor,
|
|
||||||
},
|
|
||||||
// net address
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,324 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|
||||||
|
|
||||||
if checkfw.Check() != checkfw.NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this OS")
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := newNFTablesManager(context.TODO())
|
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
|
||||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
|
||||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
|
||||||
require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6")
|
|
||||||
|
|
||||||
pair := routerPair{
|
|
||||||
ID: "abc",
|
|
||||||
source: "100.100.100.1/32",
|
|
||||||
destination: "100.100.100.0/24",
|
|
||||||
masquerade: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
|
||||||
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
|
||||||
|
|
||||||
forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
|
||||||
inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.tableIPv4,
|
|
||||||
Chain: manager.chains[ipv4][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forward4Exp,
|
|
||||||
UserData: []byte(forward4RuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
|
||||||
|
|
||||||
inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.tableIPv4,
|
|
||||||
Chain: manager.chains[ipv4][nftablesRoutingNatChain],
|
|
||||||
Exprs: nat4Exp,
|
|
||||||
UserData: []byte(nat4RuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
err = nftablesTestingClient.Flush()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
pair = routerPair{
|
|
||||||
ID: "xyz",
|
|
||||||
source: "fc00::1/128",
|
|
||||||
destination: "fc11::/64",
|
|
||||||
masquerade: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions("source", pair.source)
|
|
||||||
destExp = generateCIDRMatcherExpressions("destination", pair.destination)
|
|
||||||
|
|
||||||
forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
|
||||||
inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.tableIPv6,
|
|
||||||
Chain: manager.chains[ipv6][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forward6Exp,
|
|
||||||
UserData: []byte(forward6RuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
|
||||||
|
|
||||||
inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.tableIPv6,
|
|
||||||
Chain: manager.chains[ipv6][nftablesRoutingNatChain],
|
|
||||||
Exprs: nat6Exp,
|
|
||||||
UserData: []byte(nat6RuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
err = nftablesTestingClient.Flush()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
manager.tableIPv4 = nil
|
|
||||||
manager.tableIPv6 = nil
|
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
|
||||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
|
||||||
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
|
||||||
require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6")
|
|
||||||
|
|
||||||
foundRule, found := manager.rules[forward4RuleKey]
|
|
||||||
require.True(t, found, "forwarding rule should exist in the map")
|
|
||||||
assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[nat4RuleKey]
|
|
||||||
require.True(t, found, "nat rule should exist in the map")
|
|
||||||
// match len of output as nftables client doesn't return expressions with masquerade expression
|
|
||||||
assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[forward6RuleKey]
|
|
||||||
require.True(t, found, "forwarding rule should exist in the map")
|
|
||||||
assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match")
|
|
||||||
|
|
||||||
foundRule, found = manager.rules[nat6RuleKey]
|
|
||||||
require.True(t, found, "nat rule should exist in the map")
|
|
||||||
// match len of output as nftables client doesn't return expressions with masquerade expression
|
|
||||||
assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|
||||||
if checkfw.Check() != checkfw.NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this OS")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range insertRuleTestCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
manager := newNFTablesManager(context.TODO())
|
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
|
||||||
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
|
||||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
|
||||||
fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
||||||
|
|
||||||
found := 0
|
|
||||||
for _, registeredChains := range manager.chains {
|
|
||||||
for _, chain := range registeredChains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
|
|
||||||
if testCase.inputPair.masquerade {
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
||||||
found := 0
|
|
||||||
for _, registeredChains := range manager.chains {
|
|
||||||
for _, chain := range registeredChains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
|
||||||
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
|
||||||
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
|
|
||||||
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
|
||||||
|
|
||||||
found = 0
|
|
||||||
for _, registeredChains := range manager.chains {
|
|
||||||
for _, chain := range registeredChains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
|
|
||||||
if testCase.inputPair.masquerade {
|
|
||||||
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
|
||||||
found := 0
|
|
||||||
for _, registeredChains := range manager.chains {
|
|
||||||
for _, chain := range registeredChains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|
||||||
if checkfw.Check() != checkfw.NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this OS")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range removeRuleTestCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
manager := newNFTablesManager(context.TODO())
|
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
table := manager.tableIPv4
|
|
||||||
if testCase.ipVersion == ipv6 {
|
|
||||||
table = manager.tableIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
|
||||||
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
|
||||||
|
|
||||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
||||||
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(forwardRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
||||||
|
|
||||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(natRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
|
||||||
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
|
||||||
|
|
||||||
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
|
|
||||||
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
|
||||||
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(inForwardRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
|
||||||
|
|
||||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: table,
|
|
||||||
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(inNatRuleKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
err = nftablesTestingClient.Flush()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
manager.tableIPv4 = nil
|
|
||||||
manager.tableIPv6 = nil
|
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.RemoveRoutingRules(testCase.inputPair)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
for _, registeredChains := range manager.chains {
|
|
||||||
for _, chain := range registeredChains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
|
||||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
|
||||||
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
|
||||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,24 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
type routerPair struct {
|
|
||||||
ID string
|
|
||||||
source string
|
|
||||||
destination string
|
|
||||||
masquerade bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func routeToRouterPair(source string, route *route.Route) routerPair {
|
|
||||||
parsed := netip.MustParsePrefix(source).Masked()
|
|
||||||
return routerPair{
|
|
||||||
ID: route.ID,
|
|
||||||
source: parsed.String(),
|
|
||||||
destination: route.Network.Masked().String(),
|
|
||||||
masquerade: route.Masquerade,
|
|
||||||
}
|
|
||||||
}
|
|
@ -4,9 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) {
|
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
|
||||||
return nil, fmt.Errorf("server route not supported on this os")
|
return nil, fmt.Errorf("server route not supported on this os")
|
||||||
}
|
}
|
||||||
|
@ -4,11 +4,12 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@ -17,16 +18,11 @@ type defaultServerRouter struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
routes map[string]*route.Route
|
routes map[string]*route.Route
|
||||||
firewall firewallManager
|
firewall firewall.Manager
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
|
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
|
||||||
firewall, err := newFirewall(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &defaultServerRouter{
|
return &defaultServerRouter{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
routes: make(map[string]*route.Route),
|
routes: make(map[string]*route.Route),
|
||||||
@ -38,13 +34,6 @@ func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRou
|
|||||||
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||||
serverRoutesToRemove := make([]string, 0)
|
serverRoutesToRemove := make([]string, 0)
|
||||||
|
|
||||||
if len(routesMap) > 0 {
|
|
||||||
err := m.firewall.RestoreOrCreateContainers()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for routeID := range m.routes {
|
for routeID := range m.routes {
|
||||||
update, found := routesMap[routeID]
|
update, found := routesMap[routeID]
|
||||||
if !found || !update.IsEqual(m.routes[routeID]) {
|
if !found || !update.IsEqual(m.routes[routeID]) {
|
||||||
@ -121,5 +110,22 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *defaultServerRouter) cleanUp() {
|
func (m *defaultServerRouter) cleanUp() {
|
||||||
m.firewall.CleanRoutingRules()
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
for _, r := range m.routes {
|
||||||
|
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to remove clean up route: %s", r.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
|
||||||
|
parsed := netip.MustParsePrefix(source).Masked()
|
||||||
|
return firewall.RouterPair{
|
||||||
|
ID: route.ID,
|
||||||
|
Source: parsed.String(),
|
||||||
|
Destination: route.Network.Masked().String(),
|
||||||
|
Masquerade: route.Masquerade,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
Loading…
Reference in New Issue
Block a user