[client] Set up firewall rules for dns routes dynamically based on dns response (#3702)

This commit is contained in:
Viktor Liu 2025-04-24 17:37:28 +02:00 committed by GitHub
parent 85f92f8321
commit 4a9049566a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 1399 additions and 591 deletions

View File

@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@ -243,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@ -57,18 +57,18 @@ type ruleInfo struct {
} }
type routeFilteringRuleParams struct { type routeFilteringRuleParams struct {
Sources []netip.Prefix Source firewall.Network
Destination netip.Prefix Destination firewall.Network
Proto firewall.Protocol Proto firewall.Protocol
SPort *firewall.Port SPort *firewall.Port
DPort *firewall.Port DPort *firewall.Port
Direction firewall.RuleDirection Direction firewall.RuleDirection
Action firewall.Action Action firewall.Action
SetName string
} }
type routeRules map[string][]string type routeRules map[string][]string
// the ipset library currently does not support comments, so we use the name only (string)
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct { type router struct {
@ -129,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@ -140,27 +140,28 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil return ruleKey, nil
} }
var setName string var source firewall.Network
if len(sources) > 1 { if len(sources) > 1 {
setName = firewall.GenerateSetName(sources) source.Set = firewall.NewPrefixSet(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { } else if len(sources) > 0 {
return nil, fmt.Errorf("create or get ipset: %w", err) source.Prefix = sources[0]
}
} }
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Sources: sources, Source: source,
Destination: destination, Destination: destination,
Proto: proto, Proto: proto,
SPort: sPort, SPort: sPort,
DPort: dPort, DPort: dPort,
Action: action, Action: action,
SetName: setName,
} }
rule := genRouteFilteringRuleSpec(params) rule, err := r.genRouteRuleSpec(params, sources)
if err != nil {
return nil, fmt.Errorf("generate route rule spec: %w", err)
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end // Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
// after the established rule // after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...) err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
@ -183,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err) return fmt.Errorf("delete route rule: %v", err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if setName != "" { if err := r.decrementSetCounter(rule); err != nil {
if _, err := r.ipsetCounter.Decrement(setName); err != nil { return fmt.Errorf("decrement ipset counter: %w", err)
return fmt.Errorf("failed to remove ipset: %w", err)
}
} }
} else { } else {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
@ -204,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil return nil
} }
func (r *router) findSetNameInRule(rule []string) string { func (r *router) decrementSetCounter(rule []string) error {
for i, arg := range rule { sets := r.findSets(rule)
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { var merr *multierror.Error
return rule[i+3] for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
} }
} }
return ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule []string) []string {
var sets []string
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3])
}
}
return sets
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) error { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
@ -231,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil { if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err) return fmt.Errorf("destroy set %s: %w", setName, err)
} }
log.Debugf("Deleted unused ipset %s", setName)
return nil return nil
} }
@ -270,12 +282,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
log.Errorf("%v", err) log.Errorf("%v", err)
} }
if err := r.removeNatRule(pair); err != nil { if pair.Masquerade {
return fmt.Errorf("remove nat rule: %w", err) if err := r.removeNatRule(pair); err != nil {
} return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err) return fmt.Errorf("remove inverse nat rule: %w", err)
}
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
@ -313,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey) if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} }
return nil return nil
@ -599,12 +615,24 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
rule = append(rule, rule = append(rule,
"-m", "conntrack", "-m", "conntrack",
"--ctstate", "NEW", "--ctstate", "NEW",
"-s", pair.Source.String(), )
"-d", pair.Destination.String(), sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
) )
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
} }
@ -622,6 +650,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else { } else {
log.Debugf("marking rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
@ -787,17 +819,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
var rule []string var rule []string
if params.SetName != "" { sourceExp, err := r.applyNetwork("-s", params.Source, sources)
rule = append(rule, "-m", "set", matchSet, params.SetName, "src") if err != nil {
} else if len(params.Sources) > 0 { return nil, fmt.Errorf("apply network -s: %w", err)
source := params.Sources[0]
rule = append(rule, "-s", source.String()) }
destExp, err := r.applyNetwork("-d", params.Destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
} }
rule = append(rule, "-d", params.Destination.String()) rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL { if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto))) rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
@ -807,7 +843,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
rule = append(rule, "-j", actionToStr(params.Action)) rule = append(rule, "-j", actionToStr(params.Action))
return rule return rule, nil
}
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
var merr *multierror.Error
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
}
return nberrors.FormatErrorOrNil(merr)
} }
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {

View File

@ -60,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
Masquerade: true, Masquerade: true,
} }
@ -332,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map
@ -347,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
assert.NoError(t, err, "Failed to check rule existence") assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables") assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content // Verify rule content
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Sources: tt.sources, Source: source,
Destination: tt.destination, Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto, Proto: tt.proto,
SPort: tt.sPort, SPort: tt.sPort,
DPort: tt.dPort, DPort: tt.dPort,
Action: tt.action, Action: tt.action,
SetName: "",
} }
expectedRule := genRouteFilteringRuleSpec(params) expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet { if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources) setName := firewall.NewPrefixSet(tt.sources).HashedName()
params.SetName = setName expectedRule, err = r.genRouteRuleSpec(params, nil)
expectedRule = genRouteFilteringRuleSpec(params) require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created // Check if the set was created
_, exists := r.ipsetCounter.Get(setName) _, exists := r.ipsetCounter.Get(setName)
@ -378,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}) })
} }
} }
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
expected []string
}{
{
name: "Basic rule with two sets",
rule: []string{
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
},
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
},
{
name: "No sets",
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
expected: []string{},
},
{
name: "Multiple sets with different positions",
rule: []string{
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
},
expected: []string{"set1", "set-abc123"},
},
{
name: "Boundary case - sequence appears at end",
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
expected: []string{"final-set"},
},
{
name: "Incomplete pattern - missing set name",
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
expected: []string{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
return
}
for i, set := range result {
if set != tc.expected[i] {
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
}
}
})
}
}

View File

@ -1,13 +1,10 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sort" "sort"
"strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -43,6 +40,18 @@ const (
// Action is the action to be taken on a rule // Action is the action to be taken on a rule
type Action int type Action int
// String returns the string representation of the action
func (a Action) String() string {
switch a {
case ActionAccept:
return "accept"
case ActionDrop:
return "drop"
default:
return "unknown"
}
}
const ( const (
// ActionAccept is the action to accept a packet // ActionAccept is the action to accept a packet
ActionAccept Action = iota ActionAccept Action = iota
@ -50,6 +59,33 @@ const (
ActionDrop ActionDrop
) )
// Network is a rule destination, either a set or a prefix
type Network struct {
Set Set
Prefix netip.Prefix
}
// String returns the string representation of the destination
func (d Network) String() string {
if d.Prefix.IsValid() {
return d.Prefix.String()
}
if d.IsSet() {
return d.Set.HashedName()
}
return "<invalid network>"
}
// IsSet returns true if the destination is a set
func (d Network) IsSet() bool {
return d.Set != Set{}
}
// IsPrefix returns true if the destination is a valid prefix
func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// Manager is the high level abstraction of a firewall manager // Manager is the high level abstraction of a firewall manager
// //
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
@ -83,10 +119,9 @@ type Manager interface {
AddRouteFiltering( AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination Network,
proto Protocol, proto Protocol,
sPort *Port, sPort, dPort *Port,
dPort *Port,
action Action, action Action,
) (Rule, error) ) (Rule, error)
@ -119,6 +154,9 @@ type Manager interface {
// DeleteDNATRule deletes a DNAT rule // DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {
@ -153,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
return nil return nil
} }
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix // MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 { if len(prefixes) == 0 {

View File

@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.GenerateSetName(prefixes1) result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.GenerateSetName(prefixes2) result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("10.0.0.0/8"),
} }
result := manager.GenerateSetName(prefixes) result := manager.NewPrefixSet(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
if err != nil { if err != nil {
t.Fatalf("Error matching regex: %v", err) t.Fatalf("Error matching regex: %v", err)
} }
@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
}) })
t.Run("Empty input produces consistent result", func(t *testing.T) { t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{}) result1 := manager.NewPrefixSet([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{}) result2 := manager.NewPrefixSet([]netip.Prefix{})
if result1 != result2 { if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.GenerateSetName(prefixes1) result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.GenerateSetName(prefixes2) result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)

View File

@ -1,15 +1,13 @@
package manager package manager
import ( import (
"net/netip"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type RouterPair struct { type RouterPair struct {
ID route.ID ID route.ID
Source netip.Prefix Source Network
Destination netip.Prefix Destination Network
Masquerade bool Masquerade bool
Inverse bool Inverse bool
} }

View File

@ -0,0 +1,74 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
)
type Set struct {
hash [4]byte
comment string
}
// String returns the string representation of the set: hashed name and comment
func (h Set) String() string {
if h.comment == "" {
return h.HashedName()
}
return h.HashedName() + ": " + h.comment
}
// HashedName returns the string representation of the hash
func (h Set) HashedName() string {
return fmt.Sprintf(
"nb-%s",
hex.EncodeToString(h.hash[:]),
)
}
// Comment returns the comment of the set
func (h Set) Comment() string {
return h.comment
}
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()
for _, src := range prefixes {
bytes, err := src.MarshalBinary()
if err != nil {
log.Warnf("failed to marshal prefix %s: %v", src, err)
}
hash.Write(bytes)
}
var set Set
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}
// NewDomainSet generates a unique name for an ipset based on the given domains.
func NewDomainSet(domains domain.List) Set {
slices.Sort(domains)
hash := sha256.New()
for _, d := range domains {
hash.Write([]byte(d.PunycodeString()))
}
set := Set{
comment: domains.SafeString(),
}
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}

View File

@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@ -242,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Reset firewall to the default state // Close closes the firewall manager
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -359,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@ -289,7 +289,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
nil, nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"), fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
&fw.Port{Values: []uint16{443}}, &fw.Port{Values: []uint16{443}},
@ -298,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, err, "failed to add route filtering rule") require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{ pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"), Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
Destination: netip.MustParsePrefix("10.0.0.0/24"), Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
Masquerade: true, Masquerade: true,
} }
err = manager.AddNatRule(pair) err = manager.AddNatRule(pair)

View File

@ -10,7 +10,6 @@ import (
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
@ -44,9 +43,14 @@ const (
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
var ( var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") errFilterTableNotFound = fmt.Errorf("'filter' table not found")
) )
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
type router struct { type router struct {
conn *nftables.Conn conn *nftables.Conn
workTable *nftables.Table workTable *nftables.Table
@ -54,7 +58,7 @@ type router struct {
chains map[string]*nftables.Chain chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
@ -163,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err) return nil, fmt.Errorf("unable to list tables: %v", err)
} }
for _, table := range tables { for _, table := range tables {
@ -316,7 +320,7 @@ func (r *router) setupDataPlaneMark() error {
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@ -331,23 +335,29 @@ func (r *router) AddRouteFiltering(
chain := r.chains[chainNameRoutingFw] chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any var exprs []expr.Any
var source firewall.Network
switch { switch {
case len(sources) == 1 && sources[0].Bits() == 0: case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching // If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1: case len(sources) == 1:
// If there's only one source, we can use it directly // If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) source.Prefix = sources[0]
default: default:
// If there are multiple sources, create or get an ipset // If there are multiple sources, use a set
var err error source.Set = firewall.NewPrefixSet(sources)
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
} }
// Handle destination sourceExp, err := r.applyNetwork(source, sources, true)
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
exprs = append(exprs, sourceExp...)
destExp, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExp...)
// Handle protocol // Handle protocol
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
@ -391,39 +401,27 @@ func (r *router) AddRouteFiltering(
rule = r.conn.AddRule(rule) rule = r.conn.AddRule(rule)
} }
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err) return nil, fmt.Errorf(flushError, err)
} }
r.rules[string(ruleKey)] = rule r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil return ruleKey, nil
} }
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources) ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
ref, err := r.ipsetCounter.Increment(setName, sources) set: set,
prefixes: prefixes,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
exprs = append(exprs, return getIpSetExprs(ref, isSource)
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@ -442,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf("route rule %s has no handle", ruleKey) return fmt.Errorf("route rule %s has no handle", ruleKey)
} }
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil { if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err) return fmt.Errorf("delete: %w", err)
} }
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
if err := r.decrementSetCounter(nftRule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil return nil
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them // overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources) prefixes := firewall.MergeIPRanges(input.prefixes)
set := &nftables.Set{ nfset := &nftables.Set{
Name: setName, Name: setName,
Table: r.workTable, Comment: input.set.Comment(),
Table: r.workTable,
// required for prefixes // required for prefixes
Interval: true, Interval: true,
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
} }
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement var elements []nftables.SetElement
for _, prefix := range sources { for _, prefix := range prefixes {
// TODO: Implement IPv6 support // TODO: Implement IPv6 support
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue continue
} }
@ -493,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
) )
} }
return elements
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
} }
// calculateLastIP determines the last IP in a given prefix. // calculateLastIP determines the last IP in a given prefix.
@ -528,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
return b return b
} }
func (r *router) deleteIpSet(setName string, set *nftables.Set) error { func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(set) r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@ -538,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
return nil return nil
} }
func (r *router) findSetNameInRule(rule *nftables.Rule) string { func (r *router) decrementSetCounter(rule *nftables.Rule) error {
for _, e := range rule.Exprs { sets := r.findSets(rule)
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
} }
} }
return ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName)
}
}
return sets
} }
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
@ -586,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) // TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
} }
return nil return nil
@ -594,19 +608,22 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// addNatRule inserts a nftables rule to the conn client flush queue // addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error { func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq op := expr.CmpOpEq
if pair.Inverse { if pair.Inverse {
op = expr.CmpOpNeq op = expr.CmpOpNeq
} }
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. exprs := []expr.Any{
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs := getCtNewExprs()
exprs = append(exprs,
// interface matching
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyIIFNAME, Key: expr.MetaKeyIIFNAME,
Register: 1, Register: 1,
@ -616,7 +633,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Register: 1, Register: 1,
Data: ifname(r.wgIface.Name()), Data: ifname(r.wgIface.Name()),
}, },
) }
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...) exprs = append(exprs, destExp...)
@ -729,8 +749,15 @@ func (r *router) addPostroutingRules() error {
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Counter{}, &expr.Counter{},
@ -739,7 +766,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}, },
} }
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
@ -752,7 +780,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameRoutingFw], Chain: r.chains[chainNameRoutingFw],
Exprs: expression, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
return nil return nil
@ -767,11 +795,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} }
return nil return nil
@ -982,12 +1012,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
if err := r.removeNatRule(pair); err != nil { if pair.Masquerade {
return fmt.Errorf("remove prerouting rule: %w", err) if err := r.removeNatRule(pair); err != nil {
} return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err) return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
@ -995,10 +1027,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) // TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
} }
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil return nil
} }
@ -1006,16 +1038,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule) if err := r.conn.DelRule(rule); err != nil {
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else { } else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey) log.Debugf("prerouting rule %s not found", ruleKey)
} }
return nil return nil
@ -1027,7 +1062,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains { for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain) rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err) return fmt.Errorf(" unable to list rules: %v", err)
} }
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
@ -1301,13 +1336,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
var offset uint32 if err != nil {
if source { return fmt.Errorf("get set %s: %w", set.HashedName(), err)
offset = 12 // src offset }
} else {
offset = 16 // dst offset elements := convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
return nil, fmt.Errorf("source: %w", err)
}
return exprs, nil
}
if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
} }
ones := prefix.Bits() ones := prefix.Bits()
@ -1415,3 +1491,27 @@ func getCtNewExprs() []expr.Any {
}, },
} }
} }
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
// Build CIDR matching expressions // Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order // Combine all expressions in the correct order
// nolint:gocritic // nolint:gocritic
@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() { t.Cleanup(func() {
@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources) setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, tt.sources) set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
if err != nil { if err != nil {
t.Logf("Failed to create IP set: %v", err) t.Logf("Failed to create IP set: %v", err)
printNftSets() printNftSets()

View File

@ -15,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: false, Masquerade: false,
}, },
}, },
@ -24,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true, Masquerade: true,
}, },
}, },
@ -40,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Reset firewall to the default state // Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@ -22,7 +21,7 @@ const (
firewallRuleName = "Netbird" firewallRuleName = "Netbird"
) )
// Reset firewall to the default state // Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if fwder := m.forwarder.Load(); fwder != nil { if fwder := m.forwarder.Load(); fwder != nil {

View File

@ -29,14 +29,15 @@ func (r *PeerRule) ID() string {
} }
type RouteRule struct { type RouteRule struct {
id string id string
mgmtId []byte mgmtId []byte
sources []netip.Prefix sources []netip.Prefix
destination netip.Prefix dstSet firewall.Set
proto firewall.Protocol destinations []netip.Prefix
srcPort *firewall.Port proto firewall.Protocol
dstPort *firewall.Port srcPort *firewall.Port
action firewall.Action dstPort *firewall.Port
action firewall.Action
} }
// ID returns the rule id // ID returns the rule id

View File

@ -199,7 +199,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err) require.NoError(t, err)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
@ -223,7 +223,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err) require.NoError(t, err)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {

View File

@ -49,10 +49,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule type RuleSet map[string]PeerRule
type RouteRules []RouteRule type RouteRules []*RouteRule
func (r RouteRules) Sort() { func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b RouteRule) int { slices.SortStableFunc(r, func(a, b *RouteRule) int {
// Deny rules come first // Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
return -1 return -1
@ -99,6 +99,8 @@ type Manager struct {
forwarder atomic.Pointer[forwarder.Forwarder] forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
blockRule firewall.Rule
} }
// decoder for packages // decoder for packages
@ -201,41 +203,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
} }
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
}
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err) return nil, fmt.Errorf("set filter: %w", err)
} }
return m, nil return m, nil
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
if m.forwarder.Load() == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil { if err != nil {
return fmt.Errorf("parse wireguard network: %w", err) return nil, fmt.Errorf("parse wireguard network: %w", err)
} }
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
if _, err := m.AddRouteFiltering( rule, err := m.addRouteFiltering(
nil, nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
wgPrefix, firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL, firewall.ProtocolALL,
nil, nil,
nil, nil,
firewall.ActionDrop, firewall.ActionDrop,
); err != nil { )
return fmt.Errorf("block wg nte : %w", err) if err != nil {
return nil, fmt.Errorf("block wg nte : %w", err)
} }
// TODO: Block networks that we're a client of // TODO: Block networks that we're a client of
return nil return rule, nil
} }
func (m *Manager) determineRouting() error { func (m *Manager) determineRouting() error {
@ -413,10 +409,23 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port, action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) addRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
@ -426,34 +435,39 @@ func (m *Manager) AddRouteFiltering(
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
id: ruleID, id: ruleID,
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
destination: destination, dstSet: destination.Set,
proto: proto, proto: proto,
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
} }
m.mutex.Lock() m.routeRules = append(m.routeRules, &rule)
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort() m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.ID() ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID return r.id == ruleID
}) })
if idx < 0 { if idx < 0 {
@ -509,6 +523,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteDNATRule(rule) return m.nativeFirewall.DeleteDNATRule(rule)
} }
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.UpdateSet(set, prefixes)
}
m.mutex.Lock()
defer m.mutex.Unlock()
var matches []*RouteRule
for _, rule := range m.routeRules {
if rule.dstSet == set {
matches = append(matches, rule)
}
}
if len(matches) == 0 {
return fmt.Errorf("no route rule found for set: %s", set)
}
destinations := matches[0].destinations
for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr())
if cmp != 0 {
return cmp
}
return a.Bits() - b.Bits()
})
destinations = slices.Compact(destinations)
for _, rule := range matches {
rule.destinations = destinations
}
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
return nil
}
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool { func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size) return m.processOutgoingHooks(packetData, size)
@ -988,8 +1048,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
return nil, false return nil, false
} }
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) { destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
destMatched = true
break
}
}
if !destMatched {
return false return false
} }
@ -1091,7 +1158,22 @@ func (m *Manager) EnableRouting() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.determineRouting() if err := m.determineRouting(); err != nil {
return fmt.Errorf("determine routing: %w", err)
}
if m.forwarder.Load() == nil {
return nil
}
rule, err := m.blockInvalidRouted(m.wgIface)
if err != nil {
return fmt.Errorf("block invalid routed: %w", err)
}
m.blockRule = rule
return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
@ -1116,5 +1198,12 @@ func (m *Manager) DisableRouting() error {
log.Debug("forwarder stopped") log.Debug("forwarder stopped")
if m.blockRule != nil {
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil return nil
} }

View File

@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
@ -600,8 +601,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err) require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled.Load()) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load()) require.False(tb, manager.nativeRouter.Load())
@ -618,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) {
type rule struct { type rule struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@ -644,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -660,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -676,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -692,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 53, dstPort: 53,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{53}}, dstPort: &fw.Port{Values: []uint16{53}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -706,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@ -721,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -737,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -753,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -769,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -785,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -804,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -818,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@ -833,7 +834,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -849,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -865,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -881,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
@ -898,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@ -917,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 7999, dstPort: 7999,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@ -936,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@ -955,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@ -977,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8100, dstPort: 8100,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@ -996,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 5060, dstPort: 5060,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@ -1015,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@ -1034,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@ -1050,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
@ -1068,13 +1069,32 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
shouldPass: false, shouldPass: false,
}, },
{
name: "Drop empty destination set",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Set: fw.Set{}},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{ {
name: "Accept TCP traffic outside drop port range", name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1", srcIP: "100.10.0.1",
@ -1084,7 +1104,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 7999, dstPort: 7999,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
action: fw.ActionDrop, action: fw.ActionDrop,
@ -1100,7 +1120,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@ -1115,7 +1135,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 53, dstPort: 53,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@ -1130,7 +1150,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -1146,7 +1166,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -1160,7 +1180,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@ -1173,7 +1193,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@ -1188,7 +1208,7 @@ func TestRouteACLFiltering(t *testing.T) {
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil, nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
netip.MustParsePrefix("0.0.0.0/0"), fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
fw.ProtocolALL, fw.ProtocolALL,
nil, nil,
nil, nil,
@ -1235,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) {
name string name string
rules []struct { rules []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@ -1256,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Drop rules take precedence over accept", name: "Drop rules take precedence over accept",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@ -1265,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept rule added first // Accept rule added first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 443}}, dstPort: &fw.Port{Values: []uint16{80, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@ -1273,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop rule added second but should be evaluated first // Drop rule added second but should be evaluated first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@ -1311,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Multiple drop rules take precedence", name: "Multiple drop rules take precedence",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@ -1320,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept all // Accept all
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
{ {
// Drop specific port // Drop specific port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@ -1335,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop different port // Drop different port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
@ -1414,3 +1434,53 @@ func TestRouteACLOrder(t *testing.T) {
}) })
} }
} }
func TestRouteACLSet(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/management/domain"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
@ -711,3 +712,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}) })
} }
} }
func TestUpdateSetMerge(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
// Update the set with initial prefixes
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Test initial prefixes work
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP1 := netip.MustParseAddr("10.0.0.100")
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
newPrefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.1.0.0/24"),
}
err = manager.UpdateSet(set, newPrefixes)
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
// Check that new prefixes are included
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
// Verify the rule has all prefixes
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
"Rule should have all prefixes merged")
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
}
func TestUpdateSetDeduplication(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
}
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Check the internal state for deduplication
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
// Should have deduplicated to 2 prefixes
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
// Check the prefixes are correct
expectedPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
for i, prefix := range expectedPrefixes {
require.True(t, r.destinations[i] == prefix,
"Prefix should match expected value")
}
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
// Test with overlapping prefixes of different sizes
overlappingPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/16"), // More general
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
netip.MustParsePrefix("192.168.0.0/16"), // More general
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
}
err = manager.UpdateSet(set, overlappingPrefixes)
require.NoError(t, err)
// Check that all prefixes are included (no deduplication of overlapping prefixes)
manager.mutex.RLock()
for _, r := range manager.routeRules {
if r.id == rule.ID() {
// Should have all 4 prefixes (2 original + 2 new more general ones)
require.Len(t, r.destinations, 4,
"Overlapping prefixes should not be deduplicated")
// Verify they're sorted correctly (more specific prefixes should come first)
prefixes := make([]string, 0, len(r.destinations))
for _, p := range r.destinations {
prefixes = append(prefixes, p.String())
}
// Check sorted order
require.Equal(t, []string{
"10.0.0.0/16",
"10.0.0.0/24",
"192.168.0.0/16",
"192.168.1.0/24",
}, prefixes, "Prefixes should be sorted")
}
}
manager.mutex.RUnlock()
// Test functionality with all prefixes
testCases := []struct {
dstIP netip.Addr
expected bool
desc string
}{
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
}
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@ -18,7 +18,7 @@ func (r RuleID) ID() string {
func GenerateRouteRuleKey( func GenerateRouteRuleKey(
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination manager.Network,
proto manager.Protocol, proto manager.Protocol,
sPort *manager.Port, sPort *manager.Port,
dPort *manager.Port, dPort *manager.Port,

View File

@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
} }
type protoMatch struct { type protoMatch struct {
@ -53,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
// //
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains. // If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
@ -82,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
log.Errorf("failed to set legacy management flag: %v", err) log.Errorf("failed to set legacy management flag: %v", err)
} }
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@ -176,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.peerRulesPairs = newRulePairs d.peerRulesPairs = newRulePairs
} }
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules)) newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present // Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules { for _, rule := range rules {
id, err := d.applyRouteACL(rule) id, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil { if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) { if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
} else { } else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
} }
@ -208,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 { if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty return "", ErrSourceRangesEmpty
} }
@ -222,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
sources = append(sources, source) sources = append(sources, source)
} }
var destination netip.Prefix destination, err := determineDestination(rule, dynamicResolver, sources)
if rule.IsDynamic { if err != nil {
destination = getDefault(sources[0]) return "", fmt.Errorf("determine destination: %w", err)
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
} }
protocol, err := convertToFirewallProtocol(rule.Protocol) protocol, err := convertToFirewallProtocol(rule.Protocol)
@ -580,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
return nil return nil
} }
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
var destination firewall.Network
if rule.IsDynamic {
if dynamicResolver {
if len(rule.Domains) > 0 {
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
} else {
// isDynamic is set but no domains = outdated management server
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
destination.Prefix = getDefault(sources[0])
}
} else {
// client resolves DNS, we (router) don't know the destination
destination.Prefix = getDefault(sources[0])
}
return destination, nil
}
prefix, err := netip.ParsePrefix(rule.Destination)
if err != nil {
return destination, fmt.Errorf("parse destination: %w", err)
}
destination.Prefix = prefix
return destination, nil
}
func getDefault(prefix netip.Prefix) netip.Prefix { func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0) return netip.PrefixFrom(netip.IPv6Unspecified(), 0)

View File

@ -66,7 +66,7 @@ func TestDefaultManager(t *testing.T) {
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 { if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
@ -92,7 +92,7 @@ func TestDefaultManager(t *testing.T) {
}, },
) )
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules // we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 { if len(acl.peerRulesPairs) != 2 {
@ -116,13 +116,13 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return return
} }
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 { if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return return
@ -359,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 { if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))

View File

@ -59,6 +59,16 @@ func collectIPTablesRules() (string, error) {
builder.WriteString("\n") builder.WriteString("\n")
} }
// Collect ipset information
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n") builder.WriteString("=== iptables -v -n -L output ===\n")
tables := []string{"filter", "nat", "mangle", "raw", "security"} tables := []string{"filter", "nat", "mangle", "raw", "security"}
@ -78,6 +88,28 @@ func collectIPTablesRules() (string, error) {
return builder.String(), nil return builder.String(), nil
} }
// collectIPSets collects information about ipsets
func collectIPSets() (string, error) {
cmd := exec.Command("ipset", "list")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("ipset command not found: %w", err)
}
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
}
ipsets := stdout.String()
if strings.TrimSpace(ipsets) == "" {
return "No ipsets found", nil
}
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions // collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) { func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save") cmd := exec.Command("iptables-save")

View File

@ -3,6 +3,7 @@ package dnsfwd
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"math" "math"
"net" "net"
"net/netip" "net/netip"
@ -10,11 +11,16 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
) )
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
@ -23,25 +29,27 @@ const upstreamTimeout = 15 * time.Second
type DNSForwarder struct { type DNSForwarder struct {
listenAddress string listenAddress string
ttl uint32 ttl uint32
domains []string
statusRecorder *peer.Status statusRecorder *peer.Status
dnsServer *dns.Server dnsServer *dns.Server
mux *dns.ServeMux mux *dns.ServeMux
resId sync.Map mutex sync.RWMutex
fwdEntries []*ForwarderEntry
firewall firewall.Manager
} }
func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
ttl: ttl, ttl: ttl,
firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
} }
} }
func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error { func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress) log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux() mux := dns.NewServeMux()
@ -53,31 +61,35 @@ func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error
f.dnsServer = dnsServer f.dnsServer = dnsServer
f.mux = mux f.mux = mux
f.UpdateDomains(domains, resIds) f.UpdateDomains(entries)
return dnsServer.ListenAndServe() return dnsServer.ListenAndServe()
} }
func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
log.Debugf("Updating domains from %v to %v", f.domains, domains) f.mutex.Lock()
defer f.mutex.Unlock()
for _, d := range f.domains { if f.mux == nil {
f.mux.HandleRemove(d) log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
} }
f.resId.Clear()
newDomains := filterDomains(domains) oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains { for _, d := range newDomains {
f.mux.HandleFunc(d, f.handleDNSQuery) f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
} }
for domain, resId := range resIds { f.fwdEntries = entries
if domain != "" {
f.resId.Store(domain, resId)
}
}
f.domains = newDomains log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
@ -91,11 +103,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if len(query.Question) == 0 { if len(query.Question) == 0 {
return return
} }
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
question := query.Question[0] question := query.Question[0]
domain := question.Name log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
domain := strings.ToLower(question.Name)
resp := query.SetReply(query) resp := query.SetReply(query)
var network string var network string
@ -122,21 +134,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
return return
} }
resId := f.getResIdForDomain(strings.TrimSuffix(domain, ".")) f.updateInternalState(domain, ips)
if resId != "" {
for _, ip := range ips {
var ipWithSuffix string
if ip.Is4() {
ipWithSuffix = ip.String() + "/32"
log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix)
} else {
ipWithSuffix = ip.String() + "/128"
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
}
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId)
}
}
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
@ -144,6 +142,42 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
} }
} }
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" {
for _, ip := range ips {
var prefix netip.Prefix
if ip.Is4() {
prefix = netip.PrefixFrom(ip, 32)
} else {
prefix = netip.PrefixFrom(ip, 128)
}
prefixes = append(prefixes, prefix)
f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId)
}
}
if f.firewall != nil {
f.updateFirewall(matchingEntries, prefixes)
}
}
func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) {
var merr *multierror.Error
for _, entry := range matchingEntries {
if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err))
}
}
if merr != nil {
log.Errorf("failed to update firewall sets (%d/%d): %v",
len(merr.Errors),
len(matchingEntries),
nberrors.FormatErrorOrNil(merr))
}
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response // handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError var dnsErr *net.DNSError
@ -204,45 +238,53 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
} }
} }
func (f *DNSForwarder) getResIdForDomain(domain string) string { // getMatchingEntries retrieves the resource IDs for a given domain.
var selectedResId string // It returns the most specific match and all matching resource IDs.
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {
var selectedResId route.ResID
var bestScore int var bestScore int
var matches []*ForwarderEntry
f.resId.Range(func(key, value interface{}) bool { f.mutex.RLock()
defer f.mutex.RUnlock()
for _, entry := range f.fwdEntries {
var score int var score int
pattern := key.(string) pattern := entry.Domain.PunycodeString()
switch { switch {
case strings.HasPrefix(pattern, "*."): case strings.HasPrefix(pattern, "*."):
baseDomain := strings.TrimPrefix(pattern, "*.") baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) {
score = len(baseDomain) score = len(baseDomain)
matches = append(matches, entry)
} }
case domain == pattern: case domain == pattern:
score = math.MaxInt score = math.MaxInt
matches = append(matches, entry)
default: default:
return true continue
} }
if score > bestScore { if score > bestScore {
bestScore = score bestScore = score
selectedResId = value.(string) selectedResId = entry.ResID
} }
return true }
})
return selectedResId return selectedResId, matches
} }
// filterDomains returns a list of normalized domains // filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string { func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make([]string, 0, len(domains)) newDomains := make(domain.List, 0, len(entries))
for _, d := range domains { for _, d := range entries {
if d == "" { if d.Domain == "" {
log.Warn("empty domain in DNS forwarder") log.Warn("empty domain in DNS forwarder")
continue continue
} }
newDomains = append(newDomains, nbdns.NormalizeZone(d)) newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
} }
return newDomains return newDomains
} }

View File

@ -1,56 +1,61 @@
package dnsfwd package dnsfwd
import ( import (
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
) )
func TestGetResIdForDomain(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
storedMappings map[string]string // key: domain pattern, value: resId storedMappings map[string]route.ResID // key: domain pattern, value: resId
queryDomain string queryDomain string
expectedResId string expectedResId route.ResID
}{ }{
{ {
name: "Empty map returns empty string", name: "Empty map returns empty string",
storedMappings: map[string]string{}, storedMappings: map[string]route.ResID{},
queryDomain: "example.com", queryDomain: "example.com",
expectedResId: "", expectedResId: "",
}, },
{ {
name: "Exact match returns stored resId", name: "Exact match returns stored resId",
storedMappings: map[string]string{"example.com": "res1"}, storedMappings: map[string]route.ResID{"example.com": "res1"},
queryDomain: "example.com", queryDomain: "example.com",
expectedResId: "res1", expectedResId: "res1",
}, },
{ {
name: "Wildcard pattern matches base domain", name: "Wildcard pattern matches base domain",
storedMappings: map[string]string{"*.example.com": "res2"}, storedMappings: map[string]route.ResID{"*.example.com": "res2"},
queryDomain: "example.com", queryDomain: "example.com",
expectedResId: "res2", expectedResId: "res2",
}, },
{ {
name: "Wildcard pattern matches subdomain", name: "Wildcard pattern matches subdomain",
storedMappings: map[string]string{"*.example.com": "res3"}, storedMappings: map[string]route.ResID{"*.example.com": "res3"},
queryDomain: "foo.example.com", queryDomain: "foo.example.com",
expectedResId: "res3", expectedResId: "res3",
}, },
{ {
name: "Wildcard pattern does not match different domain", name: "Wildcard pattern does not match different domain",
storedMappings: map[string]string{"*.example.com": "res4"}, storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com", queryDomain: "foo.notexample.com",
expectedResId: "", expectedResId: "",
}, },
{ {
name: "Non-wildcard pattern does not match subdomain", name: "Non-wildcard pattern does not match subdomain",
storedMappings: map[string]string{"example.com": "res5"}, storedMappings: map[string]route.ResID{"example.com": "res5"},
queryDomain: "foo.example.com", queryDomain: "foo.example.com",
expectedResId: "", expectedResId: "",
}, },
{ {
name: "Exact match over overlapping wildcard", name: "Exact match over overlapping wildcard",
storedMappings: map[string]string{ storedMappings: map[string]route.ResID{
"*.example.com": "resWildcard", "*.example.com": "resWildcard",
"foo.example.com": "resExact", "foo.example.com": "resExact",
}, },
@ -59,7 +64,7 @@ func TestGetResIdForDomain(t *testing.T) {
}, },
{ {
name: "Overlapping wildcards: Select more specific wildcard", name: "Overlapping wildcards: Select more specific wildcard",
storedMappings: map[string]string{ storedMappings: map[string]route.ResID{
"*.example.com": "resA", "*.example.com": "resA",
"*.sub.example.com": "resB", "*.sub.example.com": "resB",
}, },
@ -68,7 +73,7 @@ func TestGetResIdForDomain(t *testing.T) {
}, },
{ {
name: "Wildcard multi-level subdomain match", name: "Wildcard multi-level subdomain match",
storedMappings: map[string]string{ storedMappings: map[string]route.ResID{
"*.example.com": "resMulti", "*.example.com": "resMulti",
}, },
queryDomain: "a.b.example.com", queryDomain: "a.b.example.com",
@ -78,18 +83,21 @@ func TestGetResIdForDomain(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
fwd := &DNSForwarder{ fwd := &DNSForwarder{}
resId: sync.Map{},
}
var entries []*ForwarderEntry
for domainPattern, resId := range tc.storedMappings { for domainPattern, resId := range tc.storedMappings {
fwd.resId.Store(domainPattern, resId) d, err := domain.FromString(domainPattern)
require.NoError(t, err)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: resId,
})
} }
fwd.UpdateDomains(entries)
got := fwd.getResIdForDomain(tc.queryDomain) got, _ := fwd.getMatchingEntries(tc.queryDomain)
if got != tc.expectedResId { assert.Equal(t, got, tc.expectedResId)
t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got)
}
}) })
} }
} }

View File

@ -11,6 +11,8 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
) )
const ( const (
@ -19,6 +21,13 @@ const (
dnsTTL = 60 //seconds dnsTTL = 60 //seconds
) )
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
type ForwarderEntry struct {
Domain domain.Domain
ResID route.ResID
Set firewall.Set
}
type Manager struct { type Manager struct {
firewall firewall.Manager firewall firewall.Manager
statusRecorder *peer.Status statusRecorder *peer.Status
@ -34,7 +43,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
} }
} }
func (m *Manager) Start(domains []string, resIds map[string]string) error { func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder") log.Infof("starting DNS forwarder")
if m.dnsForwarder != nil { if m.dnsForwarder != nil {
return nil return nil
@ -44,9 +53,9 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error {
return err return err
} }
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder) m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
go func() { go func() {
if err := m.dnsForwarder.Listen(domains, resIds); err != nil { if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists // todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err) log.Errorf("failed to start DNS forwarder, err: %v", err)
} }
@ -55,12 +64,12 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error {
return nil return nil
} }
func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) { func (m *Manager) UpdateDomains(entries []*ForwarderEntry) {
if m.dnsForwarder == nil { if m.dnsForwarder == nil {
return return
} }
m.dnsForwarder.UpdateDomains(domains, resIds) m.dnsForwarder.UpdateDomains(entries)
} }
func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) Stop(ctx context.Context) error {
@ -81,34 +90,34 @@ func (m *Manager) Stop(ctx context.Context) error {
return nberrors.FormatErrorOrNil(mErr) return nberrors.FormatErrorOrNil(mErr)
} }
func (h *Manager) allowDNSFirewall() error { func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{ dport := &firewall.Port{
IsRange: false, IsRange: false,
Values: []uint16{ListenPort}, Values: []uint16{ListenPort},
} }
if h.firewall == nil { if m.firewall == nil {
return nil return nil
} }
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err
} }
h.fwRules = dnsRules m.fwRules = dnsRules
return nil return nil
} }
func (h *Manager) dropDNSFirewall() error { func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error var mErr *multierror.Error
for _, rule := range h.fwRules { for _, rule := range m.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil { if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
} }
} }
h.fwRules = nil m.fwRules = nil
return nberrors.FormatErrorOrNil(mErr) return nberrors.FormatErrorOrNil(mErr)
} }

View File

@ -527,7 +527,7 @@ func (e *Engine) blockLanAccess() {
if _, err := e.firewall.AddRouteFiltering( if _, err := e.firewall.AddRouteFiltering(
nil, nil,
[]netip.Prefix{v4}, []netip.Prefix{v4},
network, firewallManager.Network{Prefix: network},
firewallManager.ProtocolALL, firewallManager.ProtocolALL,
nil, nil,
nil, nil,
@ -960,21 +960,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
// DNS forwarder
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
// acls might need routing to be enabled, so we apply after routes
if e.acl != nil { if e.acl != nil {
e.acl.ApplyFiltering(networkMap) e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
} }
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules // Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil { if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
log.Errorf("failed to update forward rules, err: %v", err) log.Errorf("failed to update forward rules, err: %v", err)
@ -1079,29 +1079,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes return routes
} }
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) { func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderEntry {
if protoRoutes == nil { var entries []*dnsfwd.ForwarderEntry
protoRoutes = []*mgmProto.Route{} for _, route := range routes {
} if len(route.Domains) == 0 {
var dnsRoutes []string
resIds := make(map[string]string)
for _, protoRoute := range protoRoutes {
if len(protoRoute.Domains) == 0 {
continue continue
} }
if protoRoute.Peer == myPubKey { if route.Peer == myPubKey {
dnsRoutes = append(dnsRoutes, protoRoute.Domains...) domainSet := firewallManager.NewDomainSet(route.Domains)
// resource ID is the first part of the ID for _, d := range route.Domains {
resId := strings.Split(protoRoute.ID, ":") entries = append(entries, &dnsfwd.ForwarderEntry{
for _, domain := range protoRoute.Domains { Domain: d,
if len(resId) > 0 { Set: domainSet,
resIds[domain] = resId[0] ResID: route.GetResourceID(),
} })
} }
} }
} }
return dnsRoutes, resIds return entries
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
@ -1751,7 +1746,10 @@ func (e *Engine) GetWgAddr() net.IP {
} }
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) { func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
) {
if !enabled { if !enabled {
if e.dnsForwardMgr == nil { if e.dnsForwardMgr == nil {
return return
@ -1762,18 +1760,18 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[s
return return
} }
if len(domains) > 0 { if len(fwdEntries) > 0 {
log.Infof("enable domain router service for domains: %v", domains)
if e.dnsForwardMgr == nil { if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(domains, resIds); err != nil { if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err) log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
} }
log.Infof("started domain router service with %d entries", len(fwdEntries))
} else { } else {
log.Infof("update domain router service for domains: %v", domains) e.dnsForwardMgr.UpdateDomains(fwdEntries)
e.dnsForwardMgr.UpdateDomains(domains, resIds)
} }
} else if e.dnsForwardMgr != nil { } else if e.dnsForwardMgr != nil {
log.Infof("disable domain router service") log.Infof("disable domain router service")

View File

@ -6,12 +6,14 @@ import (
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/route"
) )
// routeEntry holds the route prefix and the corresponding resource ID. // routeEntry holds the route prefix and the corresponding resource ID.
type routeEntry struct { type routeEntry struct {
prefix netip.Prefix prefix netip.Prefix
resourceID string resourceID route.ResID
} }
type routeIDLookup struct { type routeIDLookup struct {
@ -24,7 +26,7 @@ type routeIDLookup struct {
resolvedIPs sync.Map resolvedIPs sync.Map
} }
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) { func (r *routeIDLookup) AddLocalRouteID(resourceID route.ResID, route netip.Prefix) {
r.localLock.Lock() r.localLock.Lock()
defer r.localLock.Unlock() defer r.localLock.Unlock()
@ -56,7 +58,7 @@ func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
} }
} }
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) { func (r *routeIDLookup) AddRemoteRouteID(resourceID route.ResID, route netip.Prefix) {
r.remoteLock.Lock() r.remoteLock.Lock()
defer r.remoteLock.Unlock() defer r.remoteLock.Unlock()
@ -87,7 +89,7 @@ func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
} }
} }
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) { func (r *routeIDLookup) AddResolvedIP(resourceID route.ResID, route netip.Prefix) {
r.resolvedIPs.Store(route.Addr(), resourceID) r.resolvedIPs.Store(route.Addr(), resourceID)
} }
@ -97,19 +99,19 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
// Lookup returns the resource ID for the given IP address // Lookup returns the resource ID for the given IP address
// and a bool indicating if the IP is an exit node. // and a bool indicating if the IP is an exit node.
func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { func (r *routeIDLookup) Lookup(ip netip.Addr) (route.ResID, bool) {
if res, ok := r.resolvedIPs.Load(ip); ok { if res, ok := r.resolvedIPs.Load(ip); ok {
return res.(string), false return res.(route.ResID), false
} }
var resourceID string var resourceID route.ResID
var isExitNode bool var isExitNode bool
r.localLock.RLock() r.localLock.RLock()
for _, entry := range r.localRoutes { for _, entry := range r.localRoutes {
if entry.prefix.Contains(ip) { if entry.prefix.Contains(ip) {
resourceID = entry.resourceID resourceID = entry.resourceID
isExitNode = (entry.prefix.Bits() == 0) isExitNode = entry.prefix.Bits() == 0
break break
} }
} }
@ -120,7 +122,7 @@ func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) {
for _, entry := range r.remoteRoutes { for _, entry := range r.remoteRoutes {
if entry.prefix.Contains(ip) { if entry.prefix.Contains(ip) {
resourceID = entry.resourceID resourceID = entry.resourceID
isExitNode = (entry.prefix.Bits() == 0) isExitNode = entry.prefix.Bits() == 0
break break
} }
} }

View File

@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
) )
const eventQueueSize = 10 const eventQueueSize = 10
@ -313,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil return nil
} }
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error { func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@ -581,7 +582,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
} }
// AddLocalPeerStateRoute adds a route to the local peer state // AddLocalPeerStateRoute adds a route to the local peer state
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) { func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@ -611,14 +612,11 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
} }
// AddResolvedIPLookupEntry adds a resolved IP lookup entry // AddResolvedIPLookupEntry adds a resolved IP lookup entry
func (d *Status) AddResolvedIPLookupEntry(route, resourceId string) { func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route) d.routeIDLookup.AddResolvedIP(resourceId, prefix)
if err == nil {
d.routeIDLookup.AddResolvedIP(resourceId, pref)
}
} }
// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry // RemoveResolvedIPLookupEntry removes a resolved IP lookup entry
@ -723,7 +721,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates d.nsGroupStates = dnsStates
} }
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) { func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()

View File

@ -234,7 +234,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
origPattern = writer.GetOrigPattern() origPattern = writer.GetOrigPattern()
} }
resolvedDomain := domain.Domain(r.Question[0].Name) resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name))
// already punycode via RegisterHandler() // already punycode via RegisterHandler()
originalDomain := domain.Domain(origPattern) originalDomain := domain.Domain(origPattern)
@ -328,6 +328,11 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Update domain prefixes using resolved domain as key // Update domain prefixes using resolved domain as key
if len(toAdd) > 0 || len(toRemove) > 0 { if len(toAdd) > 0 || len(toRemove) > 0 {
if d.route.KeepRoute {
// replace stored prefixes with old + added
// nolint:gocritic
newPrefixes = append(oldPrefixes, toAdd...)
}
d.interceptedDomains[resolvedDomain] = newPrefixes d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
@ -338,7 +343,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
originalDomain.SafeString(), originalDomain.SafeString(),
toAdd) toAdd)
} }
if len(toRemove) > 0 { if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(), resolvedDomain.SafeString(),
originalDomain.SafeString(), originalDomain.SafeString(),

View File

@ -259,8 +259,6 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
} }
m.ctx = nil
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
m.clientRoutes = nil m.clientRoutes = nil
@ -292,7 +290,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
return nil return nil
} }
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil { if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
return fmt.Errorf("update routes: %w", err) return fmt.Errorf("update routes: %w", err)
} }

View File

@ -18,7 +18,7 @@ type serverRouter struct {
func (r serverRouter) cleanUp() { func (r serverRouter) cleanUp() {
} }
func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error {
return nil return nil
} }

View File

@ -35,7 +35,10 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi
}, nil }, nil
} }
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error {
m.mux.Lock()
defer m.mux.Unlock()
serverRoutesToRemove := make([]route.ID, 0) serverRoutesToRemove := make([]route.ID, 0)
for routeID := range m.routes { for routeID := range m.routes {
@ -73,7 +76,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
continue continue
} }
err := m.addToServerNetwork(newRoute) err := m.addToServerNetwork(newRoute, useNewDNSRoute)
if err != nil { if err != nil {
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
continue continue
@ -90,57 +93,30 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
return m.ctx.Err() return m.ctx.Err()
} }
m.mux.Lock() routerPair := routeToRouterPair(route, false)
defer m.mux.Unlock() if err := m.firewall.RemoveNatRule(routerPair); err != nil {
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err) return fmt.Errorf("remove routing rules: %w", err)
} }
delete(m.routes, route.ID) delete(m.routes, route.ID)
m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString())
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
m.statusRecorder.RemoveLocalPeerStateRoute(routeStr)
return nil return nil
} }
func (m *serverRouter) addToServerNetwork(route *route.Route) error { func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error {
if m.ctx.Err() != nil { if m.ctx.Err() != nil {
log.Infof("Not adding to server network because context is done") log.Infof("Not adding to server network because context is done")
return m.ctx.Err() return m.ctx.Err()
} }
m.mux.Lock() routerPair := routeToRouterPair(route, useNewDNSRoute)
defer m.mux.Unlock() if err := m.firewall.AddNatRule(routerPair); err != nil {
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.AddNatRule(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err) return fmt.Errorf("insert routing rules: %w", err)
} }
m.routes[route.ID] = route m.routes[route.ID] = route
m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID())
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID())
return nil return nil
} }
@ -148,31 +124,29 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
func (m *serverRouter) cleanUp() { func (m *serverRouter) cleanUp() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
for _, r := range m.routes {
routerPair, err := routeToRouterPair(r)
if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err)
continue
}
err = m.firewall.RemoveNatRule(routerPair) for _, r := range m.routes {
if err != nil { routerPair := routeToRouterPair(r, false)
if err := m.firewall.RemoveNatRule(routerPair); err != nil {
log.Errorf("Failed to remove cleanup route: %v", err) log.Errorf("Failed to remove cleanup route: %v", err)
} }
} }
m.statusRecorder.CleanLocalPeerStateRoutes() m.statusRecorder.CleanLocalPeerStateRoutes()
} }
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair {
// TODO: add ipv6
source := getDefaultPrefix(route.Network) source := getDefaultPrefix(route.Network)
destination := firewall.Network{}
destination := route.Network.Masked()
if route.IsDynamic() { if route.IsDynamic() {
// TODO: add ipv6 additionally if useNewDNSRoute {
destination = getDefaultPrefix(destination) destination.Set = firewall.NewDomainSet(route.Domains)
} else {
// TODO: add ipv6 additionally
destination = getDefaultPrefix(destination.Prefix)
}
} else {
destination.Prefix = route.Network.Masked()
} }
return firewall.RouterPair{ return firewall.RouterPair{
@ -180,12 +154,16 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
Source: source, Source: source,
Destination: destination, Destination: destination,
Masquerade: route.Masquerade, Masquerade: route.Masquerade,
}, nil }
} }
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix { func getDefaultPrefix(prefix netip.Prefix) firewall.Network {
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0) return firewall.Network{
Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
}
}
return firewall.Network{
Prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
} }
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
} }

View File

@ -45,7 +45,7 @@ var sysctlFailed bool
type ruleParams struct { type ruleParams struct {
priority int priority int
fwmark int fwmark uint32
tableID int tableID int
family int family int
invert bool invert bool
@ -55,8 +55,8 @@ type ruleParams struct {
func getSetupRules() []ruleParams { func getSetupRules() []ruleParams {
return []ruleParams{ return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
} }

View File

@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
// Convert to proto format // Convert to proto format
for domain, ips := range domainMap { for domain, ips := range domainMap {
pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{ pbRoute.ResolvedIPs[domain.SafeString()] = &proto.IPList{
Ips: ips, Ips: ips,
} }
} }

View File

@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@ -414,7 +415,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
signalConnString, signalConnString,
relaysString, relaysString,
dnsServersString, dnsServersString,
overview.FQDN, domain.Domain(overview.FQDN).SafeString(),
interfaceIP, interfaceIP,
interfaceTypeString, interfaceTypeString,
rosenpassEnabledStatus, rosenpassEnabledStatus,
@ -508,7 +509,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Quantum resistance: %s\n"+ " Quantum resistance: %s\n"+
" Networks: %s\n"+ " Networks: %s\n"+
" Latency: %s\n", " Latency: %s\n",
peerState.FQDN, domain.Domain(peerState.FQDN).SafeString(),
peerState.IP, peerState.IP,
peerState.PubKey, peerState.PubKey,
peerState.Status, peerState.Status,

View File

@ -111,6 +111,5 @@ func GetParsedDomainLabel(name string) (string, error) {
// NormalizeZone returns a normalized domain name without the wildcard prefix // NormalizeZone returns a normalized domain name without the wildcard prefix
func NormalizeZone(domain string) string { func NormalizeZone(domain string) string {
d, _ := strings.CutPrefix(domain, "*.") return strings.TrimPrefix(domain, "*.")
return d
} }

21
go.mod
View File

@ -18,9 +18,9 @@ require (
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2 github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.36.0 golang.org/x/crypto v0.37.0
golang.org/x/sys v0.31.0 golang.org/x/sys v0.32.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
@ -39,7 +39,6 @@ require (
github.com/coder/websocket v1.8.12 github.com/coder/websocket v1.8.12
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18 github.com/creack/pty v1.1.18
github.com/davecgh/go-spew v1.1.1
github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2
@ -49,7 +48,7 @@ require (
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.7.0 github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/google/nftables v0.2.0 github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.1.1 github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
@ -100,10 +99,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.38.0 golang.org/x/net v0.39.0
golang.org/x/oauth2 v0.24.0 golang.org/x/oauth2 v0.24.0
golang.org/x/sync v0.12.0 golang.org/x/sync v0.13.0
golang.org/x/term v0.30.0 golang.org/x/term v0.31.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
@ -145,6 +144,7 @@ require (
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v26.1.5+incompatible // indirect github.com/docker/docker v26.1.5+incompatible // indirect
@ -183,7 +183,6 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
@ -192,7 +191,7 @@ require (
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.7 // indirect github.com/magiconair/properties v1.8.7 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/patternmatcher v0.6.0 // indirect
@ -235,7 +234,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect

43
go.sum
View File

@ -301,8 +301,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg=
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
@ -399,8 +399,6 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
@ -447,8 +445,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
@ -665,9 +663,8 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg
github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
@ -752,8 +749,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -846,8 +843,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -876,8 +873,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -902,7 +899,6 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -911,7 +907,6 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -939,14 +934,16 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@ -954,8 +951,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -969,8 +966,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -1,12 +1,17 @@
package domain package domain
import ( import (
"strings"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
// Domain represents a punycode-encoded domain string.
// This should only be converted from a string when the string already is in punycode, otherwise use FromString.
type Domain string type Domain string
// String converts the Domain to a non-punycode string. // String converts the Domain to a non-punycode string.
// For an infallible conversion, use SafeString.
func (d Domain) String() (string, error) { func (d Domain) String() (string, error) {
unicode, err := idna.ToUnicode(string(d)) unicode, err := idna.ToUnicode(string(d))
if err != nil { if err != nil {
@ -15,16 +20,17 @@ func (d Domain) String() (string, error) {
return unicode, nil return unicode, nil
} }
// SafeString converts the Domain to a non-punycode string, falling back to the original string if conversion fails. // SafeString converts the Domain to a non-punycode string, falling back to the punycode string if conversion fails.
func (d Domain) SafeString() string { func (d Domain) SafeString() string {
str, err := d.String() str, err := d.String()
if err != nil { if err != nil {
str = string(d) return string(d)
} }
return str return str
} }
// PunycodeString returns the punycode representation of the Domain. // PunycodeString returns the punycode representation of the Domain.
// This should only be used if a punycode domain is expected but only a string is supported.
func (d Domain) PunycodeString() string { func (d Domain) PunycodeString() string {
return string(d) return string(d)
} }
@ -35,5 +41,5 @@ func FromString(s string) (Domain, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return Domain(ascii), nil return Domain(strings.ToLower(ascii)), nil
} }

View File

@ -5,6 +5,7 @@ import (
"strings" "strings"
) )
// List is a slice of punycode-encoded domain strings.
type List []Domain type List []Domain
// ToStringList converts a List to a slice of string. // ToStringList converts a List to a slice of string.
@ -53,7 +54,7 @@ func (d List) String() (string, error) {
func (d List) SafeString() string { func (d List) SafeString() string {
str, err := d.String() str, err := d.String()
if err != nil { if err != nil {
return strings.Join(d.ToPunycodeList(), ", ") return d.PunycodeString()
} }
return str return str
} }
@ -101,7 +102,7 @@ func FromStringList(s []string) (List, error) {
func FromPunycodeList(s []string) List { func FromPunycodeList(s []string) List {
var dl List var dl List
for _, domain := range s { for _, domain := range s {
dl = append(dl, Domain(domain)) dl = append(dl, Domain(strings.ToLower(domain)))
} }
return dl return dl
} }

View File

@ -22,8 +22,6 @@ func ValidateDomains(domains []string) (List, error) {
var domainList List var domainList List
for _, d := range domains { for _, d := range domains {
d := strings.ToLower(d)
// handles length and idna conversion // handles length and idna conversion
punycode, err := FromString(d) punycode, err := FromString(d)
if err != nil { if err != nil {

View File

@ -1289,7 +1289,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer
if route.Peer != peer.Key { if route.Peer != peer.Key {
continue continue
} }
resourceAppliedPolicies := resourcePolicies[route.GetResourceID()] resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())]
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)

View File

@ -4,13 +4,14 @@ import "strings"
const haSeparator = "|" const haSeparator = "|"
// HAUniqueID is a unique identifier that is used to group high availability routes.
type HAUniqueID string type HAUniqueID string
func (id HAUniqueID) String() string { func (id HAUniqueID) String() string {
return string(id) return string(id)
} }
// NetID returns the Network ID from the HAUniqueID // NetID returns the NetID from the HAUniqueID
func (id HAUniqueID) NetID() NetID { func (id HAUniqueID) NetID() NetID {
if i := strings.LastIndex(string(id), haSeparator); i != -1 { if i := strings.LastIndex(string(id), haSeparator); i != -1 {
return NetID(id[:i]) return NetID(id[:i])

View File

@ -6,8 +6,6 @@ import (
"slices" "slices"
"strings" "strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@ -46,10 +44,16 @@ const (
DomainNetwork DomainNetwork
) )
// ID is the unique route ID.
type ID string type ID string
// ResID is the resourceID part of a route.ID (first part before the colon).
type ResID string
// NetID is the route network identifier, a human-readable string.
type NetID string type NetID string
// HAMap is a map of HAUniqueID to a list of routes.
type HAMap map[HAUniqueID][]*Route type HAMap map[HAUniqueID][]*Route
// NetworkType route network type // NetworkType route network type
@ -162,21 +166,25 @@ func (r *Route) IsDynamic() bool {
return r.NetworkType == DomainNetwork return r.NetworkType == DomainNetwork
} }
// GetHAUniqueID returns the HAUniqueID for the route, it can be used for grouping.
func (r *Route) GetHAUniqueID() HAUniqueID { func (r *Route) GetHAUniqueID() HAUniqueID {
if r.IsDynamic() { return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.NetString()))
domains, err := r.Domains.String()
if err != nil {
log.Errorf("Failed to convert domains to string: %v", err)
domains = r.Domains.PunycodeString()
}
return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, domains))
}
return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String()))
} }
// GetResourceID returns the Networks Resource ID from a route ID // GetResourceID returns the Networks ResID from the route ID.
func (r *Route) GetResourceID() string { // It's the part before the first colon in the ID string.
return strings.Split(string(r.ID), ":")[0] func (r *Route) GetResourceID() ResID {
return ResID(strings.Split(string(r.ID), ":")[0])
}
// NetString returns the network string.
// If the route is dynamic, it returns the domains as comma-separated punycode-encoded string.
// If the route is not dynamic, it returns the network (prefix) string.
func (r *Route) NetString() string {
if r.IsDynamic() {
return r.Domains.SafeString()
}
return r.Network.String()
} }
// ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6