From 4a9049566a5176304d802ae8331573ccad51c312 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:37:28 +0200 Subject: [PATCH] [client] Set up firewall rules for dns routes dynamically based on dns response (#3702) --- client/firewall/iptables/manager_linux.go | 17 +- client/firewall/iptables/router_linux.go | 158 ++++++--- client/firewall/iptables/router_linux_test.go | 85 ++++- client/firewall/manager/firewall.go | 66 ++-- client/firewall/manager/firewall_test.go | 16 +- client/firewall/manager/routerpair.go | 6 +- client/firewall/manager/set.go | 74 +++++ client/firewall/nftables/manager_linux.go | 19 +- .../firewall/nftables/manager_linux_test.go | 6 +- client/firewall/nftables/router_linux.go | 306 ++++++++++++------ client/firewall/nftables/router_linux_test.go | 10 +- client/firewall/test/cases_linux.go | 12 +- client/firewall/uspfilter/allow_netbird.go | 2 +- .../uspfilter/allow_netbird_windows.go | 6 +- client/firewall/uspfilter/rule.go | 17 +- client/firewall/uspfilter/tracer_test.go | 4 +- client/firewall/uspfilter/uspfilter.go | 163 +++++++--- .../uspfilter/uspfilter_filter_test.go | 158 ++++++--- client/firewall/uspfilter/uspfilter_test.go | 201 ++++++++++++ client/internal/acl/id/id.go | 2 +- client/internal/acl/manager.go | 54 +++- client/internal/acl/manager_test.go | 10 +- client/internal/debug/debug_linux.go | 32 ++ client/internal/dnsfwd/forwarder.go | 146 ++++++--- client/internal/dnsfwd/forwarder_test.go | 50 +-- client/internal/dnsfwd/manager.go | 35 +- client/internal/engine.go | 58 ++-- client/internal/peer/route.go | 20 +- client/internal/peer/status.go | 14 +- .../routemanager/dnsinterceptor/handler.go | 9 +- client/internal/routemanager/manager.go | 4 +- .../internal/routemanager/server_android.go | 2 +- .../routemanager/server_nonandroid.go | 88 ++--- .../routemanager/systemops/systemops_linux.go | 6 +- client/server/network.go | 2 +- client/status/status.go | 5 +- dns/dns.go | 3 +- go.mod | 21 +- go.sum | 43 ++- management/domain/domain.go | 12 +- management/domain/list.go | 5 +- management/domain/validate.go | 2 - management/server/types/account.go | 2 +- route/hauniqueid.go | 3 +- route/route.go | 36 ++- 45 files changed, 1399 insertions(+), 591 deletions(-) create mode 100644 client/firewall/manager/set.go diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 652ab1b3e..b229688fc 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - if !destination.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -243,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.router.DeleteDNATRule(rule) } +// UpdateSet updates the set with the given prefixes +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.UpdateSet(set, prefixes) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 869b0b359..b59c88580 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -57,18 +57,18 @@ type ruleInfo struct { } type routeFilteringRuleParams struct { - Sources []netip.Prefix - Destination netip.Prefix + Source firewall.Network + Destination firewall.Network Proto firewall.Protocol SPort *firewall.Port DPort *firewall.Port Direction firewall.RuleDirection Action firewall.Action - SetName string } type routeRules map[string][]string +// the ipset library currently does not support comments, so we use the name only (string) type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type router struct { @@ -129,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error { func (r *router) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, @@ -140,27 +140,28 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } - var setName string + var source firewall.Network if len(sources) > 1 { - setName = firewall.GenerateSetName(sources) - if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { - return nil, fmt.Errorf("create or get ipset: %w", err) - } + source.Set = firewall.NewPrefixSet(sources) + } else if len(sources) > 0 { + source.Prefix = sources[0] } params := routeFilteringRuleParams{ - Sources: sources, + Source: source, Destination: destination, Proto: proto, SPort: sPort, DPort: dPort, Action: action, - SetName: setName, } - rule := genRouteFilteringRuleSpec(params) + rule, err := r.genRouteRuleSpec(params, sources) + if err != nil { + return nil, fmt.Errorf("generate route rule spec: %w", err) + } + // Insert DROP rules at the beginning, append ACCEPT rules at the end - var err error if action == firewall.ActionDrop { // after the established rule err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...) @@ -183,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() if rule, exists := r.rules[ruleKey]; exists { - setName := r.findSetNameInRule(rule) - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { return fmt.Errorf("delete route rule: %v", err) } delete(r.rules, ruleKey) - if setName != "" { - if _, err := r.ipsetCounter.Decrement(setName); err != nil { - return fmt.Errorf("failed to remove ipset: %w", err) - } + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) } } else { log.Debugf("route rule %s not found", ruleKey) @@ -204,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } -func (r *router) findSetNameInRule(rule []string) string { - for i, arg := range rule { - if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { - return rule[i+3] +func (r *router) decrementSetCounter(rule []string) error { + sets := r.findSets(rule) + var merr *multierror.Error + for _, setName := range sets { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err)) } } - return "" + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) findSets(rule []string) []string { + var sets []string + for i, arg := range rule { + if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { + sets = append(sets, rule[i+3]) + } + } + return sets } func (r *router) createIpSet(setName string, sources []netip.Prefix) error { @@ -231,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error { if err := ipset.Destroy(setName); err != nil { return fmt.Errorf("destroy set %s: %w", setName, err) } + + log.Debugf("Deleted unused ipset %s", setName) return nil } @@ -270,12 +282,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { log.Errorf("%v", err) } - if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove nat rule: %w", err) - } + if pair.Masquerade { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse nat rule: %w", err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -313,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } delete(r.rules, ruleKey) - } else { - log.Debugf("legacy forwarding rule %s not found", ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) + } } return nil @@ -599,12 +615,24 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { rule = append(rule, "-m", "conntrack", "--ctstate", "NEW", - "-s", pair.Source.String(), - "-d", pair.Destination.String(), + ) + sourceExp, err := r.applyNetwork("-s", pair.Source, nil) + if err != nil { + return fmt.Errorf("apply network -s: %w", err) + } + destExp, err := r.applyNetwork("-d", pair.Destination, nil) + if err != nil { + return fmt.Errorf("apply network -d: %w", err) + } + + rule = append(rule, sourceExp...) + rule = append(rule, destExp...) + rule = append(rule, "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), ) if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { + // TODO: rollback ipset counter return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) } @@ -622,6 +650,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) } delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) + } } else { log.Debugf("marking rule %s not found", ruleKey) } @@ -787,17 +819,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { return nberrors.FormatErrorOrNil(merr) } -func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { +func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) { var rule []string - if params.SetName != "" { - rule = append(rule, "-m", "set", matchSet, params.SetName, "src") - } else if len(params.Sources) > 0 { - source := params.Sources[0] - rule = append(rule, "-s", source.String()) + sourceExp, err := r.applyNetwork("-s", params.Source, sources) + if err != nil { + return nil, fmt.Errorf("apply network -s: %w", err) + + } + destExp, err := r.applyNetwork("-d", params.Destination, nil) + if err != nil { + return nil, fmt.Errorf("apply network -d: %w", err) } - rule = append(rule, "-d", params.Destination.String()) + rule = append(rule, sourceExp...) + rule = append(rule, destExp...) if params.Proto != firewall.ProtocolALL { rule = append(rule, "-p", strings.ToLower(string(params.Proto))) @@ -807,7 +843,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { rule = append(rule, "-j", actionToStr(params.Action)) - return rule + return rule, nil +} + +func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) { + direction := "src" + if flag == "-d" { + direction = "dst" + } + + if network.IsSet() { + if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + return nil, fmt.Errorf("create or get ipset: %w", err) + } + + return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + } + if network.IsPrefix() { + return []string{flag, network.Prefix.String()}, nil + } + + // nolint:nilnil + return nil, nil +} + +func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + var merr *multierror.Error + for _, prefix := range prefixes { + // TODO: Implement IPv6 support + if prefix.Addr().Is6() { + log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + continue + } + if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err)) + } + } + if merr == nil { + log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + } + + return nberrors.FormatErrorOrNil(merr) } func applyPort(flag string, port *firewall.Port) []string { diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index dad77dee7..e9eeff863 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -60,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { pair := firewall.RouterPair{ ID: "abc", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.100.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")}, Masquerade: true, } @@ -332,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") // Check if the rule is in the internal map @@ -347,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) { assert.NoError(t, err, "Failed to check rule existence") assert.True(t, exists, "Rule not found in iptables") + var source firewall.Network + if len(tt.sources) > 1 { + source.Set = firewall.NewPrefixSet(tt.sources) + } else if len(tt.sources) > 0 { + source.Prefix = tt.sources[0] + } // Verify rule content params := routeFilteringRuleParams{ - Sources: tt.sources, - Destination: tt.destination, + Source: source, + Destination: firewall.Network{Prefix: tt.destination}, Proto: tt.proto, SPort: tt.sPort, DPort: tt.dPort, Action: tt.action, - SetName: "", } - expectedRule := genRouteFilteringRuleSpec(params) + expectedRule, err := r.genRouteRuleSpec(params, nil) + require.NoError(t, err, "Failed to generate expected rule spec") if tt.expectSet { - setName := firewall.GenerateSetName(tt.sources) - params.SetName = setName - expectedRule = genRouteFilteringRuleSpec(params) + setName := firewall.NewPrefixSet(tt.sources).HashedName() + expectedRule, err = r.genRouteRuleSpec(params, nil) + require.NoError(t, err, "Failed to generate expected rule spec with set") // Check if the set was created _, exists := r.ipsetCounter.Get(setName) @@ -378,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) { }) } } + +func TestFindSetNameInRule(t *testing.T) { + r := &router{} + + testCases := []struct { + name string + rule []string + expected []string + }{ + { + name: "Basic rule with two sets", + rule: []string{ + "-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src", + "-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT", + }, + expected: []string{"nb-2e5a2a05", "nb-349ae051"}, + }, + { + name: "No sets", + rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"}, + expected: []string{}, + }, + { + name: "Multiple sets with different positions", + rule: []string{ + "-m", "set", "--match-set", "set1", "src", "-p", "tcp", + "-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT", + }, + expected: []string{"set1", "set-abc123"}, + }, + { + name: "Boundary case - sequence appears at end", + rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"}, + expected: []string{"final-set"}, + }, + { + name: "Incomplete pattern - missing set name", + rule: []string{"-p", "tcp", "-m", "set", "--match-set"}, + expected: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := r.findSets(tc.rule) + + if len(result) != len(tc.expected) { + t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result) + return + } + + for i, set := range result { + if set != tc.expected[i] { + t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set) + } + } + }) + } +} diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 1d71051ef..084d19423 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,13 +1,10 @@ package manager import ( - "crypto/sha256" - "encoding/hex" "fmt" "net" "net/netip" "sort" - "strings" log "github.com/sirupsen/logrus" @@ -43,6 +40,18 @@ const ( // Action is the action to be taken on a rule type Action int +// String returns the string representation of the action +func (a Action) String() string { + switch a { + case ActionAccept: + return "accept" + case ActionDrop: + return "drop" + default: + return "unknown" + } +} + const ( // ActionAccept is the action to accept a packet ActionAccept Action = iota @@ -50,6 +59,33 @@ const ( ActionDrop ) +// Network is a rule destination, either a set or a prefix +type Network struct { + Set Set + Prefix netip.Prefix +} + +// String returns the string representation of the destination +func (d Network) String() string { + if d.Prefix.IsValid() { + return d.Prefix.String() + } + if d.IsSet() { + return d.Set.HashedName() + } + return "" +} + +// IsSet returns true if the destination is a set +func (d Network) IsSet() bool { + return d.Set != Set{} +} + +// IsPrefix returns true if the destination is a valid prefix +func (d Network) IsPrefix() bool { + return d.Prefix.IsValid() +} + // Manager is the high level abstraction of a firewall manager // // It declares methods which handle actions required by the @@ -83,10 +119,9 @@ type Manager interface { AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination Network, proto Protocol, - sPort *Port, - dPort *Port, + sPort, dPort *Port, action Action, ) (Rule, error) @@ -119,6 +154,9 @@ type Manager interface { // DeleteDNATRule deletes a DNAT rule DeleteDNATRule(Rule) error + + // UpdateSet updates the set with the given prefixes + UpdateSet(hash Set, prefixes []netip.Prefix) error } func GenKey(format string, pair RouterPair) string { @@ -153,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { return nil } -// GenerateSetName generates a unique name for an ipset based on the given sources. -func GenerateSetName(sources []netip.Prefix) string { - // sort for consistent naming - SortPrefixes(sources) - - var sourcesStr strings.Builder - for _, src := range sources { - sourcesStr.WriteString(src.String()) - } - - hash := sha256.Sum256([]byte(sourcesStr.String())) - shortHash := hex.EncodeToString(hash[:])[:8] - - return fmt.Sprintf("nb-%s", shortHash) -} - // MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { if len(prefixes) == 0 { diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go index 3f47d6679..180346906 100644 --- a/client/firewall/manager/firewall_test.go +++ b/client/firewall/manager/firewall_test.go @@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("192.168.1.0/24"), } - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) + result1 := manager.NewPrefixSet(prefixes1) + result2 := manager.NewPrefixSet(prefixes2) if result1 != result2 { t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) @@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("10.0.0.0/8"), } - result := manager.GenerateSetName(prefixes) + result := manager.NewPrefixSet(prefixes) - matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName()) if err != nil { t.Fatalf("Error matching regex: %v", err) } @@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) { }) t.Run("Empty input produces consistent result", func(t *testing.T) { - result1 := manager.GenerateSetName([]netip.Prefix{}) - result2 := manager.GenerateSetName([]netip.Prefix{}) + result1 := manager.NewPrefixSet([]netip.Prefix{}) + result2 := manager.NewPrefixSet([]netip.Prefix{}) if result1 != result2 { t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) @@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("192.168.1.0/24"), } - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) + result1 := manager.NewPrefixSet(prefixes1) + result2 := manager.NewPrefixSet(prefixes2) if result1 != result2 { t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index 8c94b7dd4..079c051d9 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,15 +1,13 @@ package manager import ( - "net/netip" - "github.com/netbirdio/netbird/route" ) type RouterPair struct { ID route.ID - Source netip.Prefix - Destination netip.Prefix + Source Network + Destination Network Masquerade bool Inverse bool } diff --git a/client/firewall/manager/set.go b/client/firewall/manager/set.go new file mode 100644 index 000000000..4c88f6eac --- /dev/null +++ b/client/firewall/manager/set.go @@ -0,0 +1,74 @@ +package manager + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/netip" + "slices" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/domain" +) + +type Set struct { + hash [4]byte + comment string +} + +// String returns the string representation of the set: hashed name and comment +func (h Set) String() string { + if h.comment == "" { + return h.HashedName() + } + return h.HashedName() + ": " + h.comment +} + +// HashedName returns the string representation of the hash +func (h Set) HashedName() string { + return fmt.Sprintf( + "nb-%s", + hex.EncodeToString(h.hash[:]), + ) +} + +// Comment returns the comment of the set +func (h Set) Comment() string { + return h.comment +} + +// NewPrefixSet generates a unique name for an ipset based on the given prefixes. +func NewPrefixSet(prefixes []netip.Prefix) Set { + // sort for consistent naming + SortPrefixes(prefixes) + + hash := sha256.New() + for _, src := range prefixes { + bytes, err := src.MarshalBinary() + if err != nil { + log.Warnf("failed to marshal prefix %s: %v", src, err) + } + hash.Write(bytes) + } + var set Set + copy(set.hash[:], hash.Sum(nil)[:4]) + + return set +} + +// NewDomainSet generates a unique name for an ipset based on the given domains. +func NewDomainSet(domains domain.List) Set { + slices.Sort(domains) + + hash := sha256.New() + for _, d := range domains { + hash.Write([]byte(d.PunycodeString())) + } + set := Set{ + comment: domains.SafeString(), + } + copy(set.hash[:], hash.Sum(nil)[:4]) + + return set +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a5809471c..e6b3a031b 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - if !destination.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -242,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { return firewall.SetLegacyManagement(m.router, isLegacy) } -// Reset firewall to the default state +// Close closes the firewall manager func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -359,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.router.DeleteDNATRule(rule) } +// UpdateSet updates the set with the given prefixes +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.UpdateSet(set, prefixes) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 373743a08..602a6b8dc 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -289,7 +289,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { _, err = manager.AddRouteFiltering( nil, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, - netip.MustParsePrefix("10.1.0.0/24"), + fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, @@ -298,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { require.NoError(t, err, "failed to add route filtering rule") pair := fw.RouterPair{ - Source: netip.MustParsePrefix("192.168.1.0/24"), - Destination: netip.MustParsePrefix("10.0.0.0/24"), + Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")}, Masquerade: true, } err = manager.AddNatRule(pair) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aff86dd90..c2ba2a072 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/coreos/go-iptables/iptables" - "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -44,9 +43,14 @@ const ( const refreshRulesMapError = "refresh rules map: %w" var ( - errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") + errFilterTableNotFound = fmt.Errorf("'filter' table not found") ) +type setInput struct { + set firewall.Set + prefixes []netip.Prefix +} + type router struct { conn *nftables.Conn workTable *nftables.Table @@ -54,7 +58,7 @@ type router struct { chains map[string]*nftables.Chain // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules rules map[string]*nftables.Rule - ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] + ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState @@ -163,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error { func (r *router) loadFilterTable() (*nftables.Table, error) { tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + return nil, fmt.Errorf("unable to list tables: %v", err) } for _, table := range tables { @@ -316,7 +320,7 @@ func (r *router) setupDataPlaneMark() error { func (r *router) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, @@ -331,23 +335,29 @@ func (r *router) AddRouteFiltering( chain := r.chains[chainNameRoutingFw] var exprs []expr.Any + var source firewall.Network switch { case len(sources) == 1 && sources[0].Bits() == 0: // If it's 0.0.0.0/0, we don't need to add any source matching case len(sources) == 1: // If there's only one source, we can use it directly - exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) + source.Prefix = sources[0] default: - // If there are multiple sources, create or get an ipset - var err error - exprs, err = r.getIpSetExprs(sources, exprs) - if err != nil { - return nil, fmt.Errorf("get ipset expressions: %w", err) - } + // If there are multiple sources, use a set + source.Set = firewall.NewPrefixSet(sources) } - // Handle destination - exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) + sourceExp, err := r.applyNetwork(source, sources, true) + if err != nil { + return nil, fmt.Errorf("apply source: %w", err) + } + exprs = append(exprs, sourceExp...) + + destExp, err := r.applyNetwork(destination, nil, false) + if err != nil { + return nil, fmt.Errorf("apply destination: %w", err) + } + exprs = append(exprs, destExp...) // Handle protocol if proto != firewall.ProtocolALL { @@ -391,39 +401,27 @@ func (r *router) AddRouteFiltering( rule = r.conn.AddRule(rule) } - log.Tracef("Adding route rule %s", spew.Sdump(rule)) if err := r.conn.Flush(); err != nil { return nil, fmt.Errorf(flushError, err) } r.rules[string(ruleKey)] = rule - log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) return ruleKey, nil } -func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { - setName := firewall.GenerateSetName(sources) - ref, err := r.ipsetCounter.Increment(setName, sources) +func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) { + ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{ + set: set, + prefixes: prefixes, + }) if err != nil { - return nil, fmt.Errorf("create or get ipset for sources: %w", err) + return nil, fmt.Errorf("create or get ipset: %w", err) } - exprs = append(exprs, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Lookup{ - SourceRegister: 1, - SetName: ref.Out.Name, - SetID: ref.Out.ID, - }, - ) - return exprs, nil + return getIpSetExprs(ref, isSource) } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -442,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return fmt.Errorf("route rule %s has no handle", ruleKey) } - setName := r.findSetNameInRule(nftRule) - if err := r.deleteNftRule(nftRule, ruleKey); err != nil { return fmt.Errorf("delete: %w", err) } - if setName != "" { - if _, err := r.ipsetCounter.Decrement(setName); err != nil { - return fmt.Errorf("decrement ipset reference: %w", err) - } - } - if err := r.conn.Flush(); err != nil { return fmt.Errorf(flushError, err) } + if err := r.decrementSetCounter(nftRule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } + return nil } -func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { +func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) { // overlapping prefixes will result in an error, so we need to merge them - sources = firewall.MergeIPRanges(sources) + prefixes := firewall.MergeIPRanges(input.prefixes) - set := &nftables.Set{ - Name: setName, - Table: r.workTable, + nfset := &nftables.Set{ + Name: setName, + Comment: input.set.Comment(), + Table: r.workTable, // required for prefixes Interval: true, KeyType: nftables.TypeIPAddr, } + elements := convertPrefixesToSet(prefixes) + if err := r.conn.AddSet(nfset, elements); err != nil { + return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) + } + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + + log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + + return nfset, nil +} + +func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement - for _, prefix := range sources { + for _, prefix := range prefixes { // TODO: Implement IPv6 support if prefix.Addr().Is6() { - log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) continue } @@ -493,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables. nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) } - - if err := r.conn.AddSet(set, elements); err != nil { - return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) - } - - if err := r.conn.Flush(); err != nil { - return nil, fmt.Errorf("flush error: %w", err) - } - - log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) - - return set, nil + return elements } // calculateLastIP determines the last IP in a given prefix. @@ -528,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte { return b } -func (r *router) deleteIpSet(setName string, set *nftables.Set) error { - r.conn.DelSet(set) +func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error { + r.conn.DelSet(nfset) if err := r.conn.Flush(); err != nil { return fmt.Errorf(flushError, err) } @@ -538,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error { return nil } -func (r *router) findSetNameInRule(rule *nftables.Rule) string { - for _, e := range rule.Exprs { - if lookup, ok := e.(*expr.Lookup); ok { - return lookup.SetName +func (r *router) decrementSetCounter(rule *nftables.Rule) error { + sets := r.findSets(rule) + + var merr *multierror.Error + for _, setName := range sets { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err)) } } - return "" + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) findSets(rule *nftables.Rule) []string { + var sets []string + for _, e := range rule.Exprs { + if lookup, ok := e.(*expr.Lookup); ok { + sets = append(sets, lookup.SetName) + } + } + return sets } func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { @@ -586,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { } if err := r.conn.Flush(); err != nil { - return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) + // TODO: rollback ipset counter + return fmt.Errorf("insert rules for %s: %v", pair.Destination, err) } return nil @@ -594,19 +608,22 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { // addNatRule inserts a nftables rule to the conn client flush queue func (r *router) addNatRule(pair firewall.RouterPair) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) + sourceExp, err := r.applyNetwork(pair.Source, nil, true) + if err != nil { + return fmt.Errorf("apply source: %w", err) + } + + destExp, err := r.applyNetwork(pair.Destination, nil, false) + if err != nil { + return fmt.Errorf("apply destination: %w", err) + } op := expr.CmpOpEq if pair.Inverse { op = expr.CmpOpNeq } - // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. - // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. - exprs := getCtNewExprs() - exprs = append(exprs, - // interface matching + exprs := []expr.Any{ &expr.Meta{ Key: expr.MetaKeyIIFNAME, Register: 1, @@ -616,7 +633,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { Register: 1, Data: ifname(r.wgIface.Name()), }, - ) + } + // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. + // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. + exprs = append(exprs, getCtNewExprs()...) exprs = append(exprs, sourceExp...) exprs = append(exprs, destExp...) @@ -729,8 +749,15 @@ func (r *router) addPostroutingRules() error { // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) + sourceExp, err := r.applyNetwork(pair.Source, nil, true) + if err != nil { + return fmt.Errorf("apply source: %w", err) + } + + destExp, err := r.applyNetwork(pair.Destination, nil, false) + if err != nil { + return fmt.Errorf("apply destination: %w", err) + } exprs := []expr.Any{ &expr.Counter{}, @@ -739,7 +766,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { }, } - expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic + exprs = append(exprs, sourceExp...) + exprs = append(exprs, destExp...) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) @@ -752,7 +780,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, Chain: r.chains[chainNameRoutingFw], - Exprs: expression, + Exprs: exprs, UserData: []byte(ruleKey), }) return nil @@ -767,11 +795,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) - } else { - log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } } return nil @@ -982,12 +1012,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf(refreshRulesMapError, err) } - if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove prerouting rule: %w", err) - } + if pair.Masquerade { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove prerouting rule: %w", err) + } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse prerouting rule: %w", err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse prerouting rule: %w", err) + } } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -995,10 +1027,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { } if err := r.conn.Flush(); err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + // TODO: rollback set counter + return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } @@ -1006,16 +1038,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - err := r.conn.DelRule(rule) - if err != nil { + if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } } else { - log.Debugf("nftables: prerouting rule %s not found", ruleKey) + log.Debugf("prerouting rule %s not found", ruleKey) } return nil @@ -1027,7 +1062,7 @@ func (r *router) refreshRulesMap() error { for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) + return fmt.Errorf(" unable to list rules: %v", err) } for _, rule := range rules { if len(rule.UserData) > 0 { @@ -1301,13 +1336,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { return nberrors.FormatErrorOrNil(merr) } -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { - var offset uint32 - if source { - offset = 12 // src offset - } else { - offset = 16 // dst offset +func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName()) + if err != nil { + return fmt.Errorf("get set %s: %w", set.HashedName(), err) + } + + elements := convertPrefixesToSet(prefixes) + if err := r.conn.SetAddElements(nfset, elements); err != nil { + return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + + return nil +} + +// applyNetwork generates nftables expressions for networks (CIDR) or sets +func (r *router) applyNetwork( + network firewall.Network, + setPrefixes []netip.Prefix, + isSource bool, +) ([]expr.Any, error) { + if network.IsSet() { + exprs, err := r.getIpSet(network.Set, setPrefixes, isSource) + if err != nil { + return nil, fmt.Errorf("source: %w", err) + } + return exprs, nil + } + + if network.IsPrefix() { + return applyPrefix(network.Prefix, isSource), nil + } + + return nil, nil +} + +// applyPrefix generates nftables expressions for a CIDR prefix +func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset + offset := uint32(16) + if isSource { + // src offset + offset = 12 } ones := prefix.Bits() @@ -1415,3 +1491,27 @@ func getCtNewExprs() []expr.Any { }, } } + +func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + + // dst offset + offset := uint32(16) + if isSource { + // src offset + offset = 12 + } + + return []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offset, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: ref.Out.Name, + SetID: ref.Out.ID, + }, + }, nil +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 28baef4dd..4fdbf3505 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // nolint:gocritic @@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") t.Cleanup(func() { @@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - setName := firewall.GenerateSetName(tt.sources) - set, err := r.createIpSet(setName, tt.sources) + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) if err != nil { t.Logf("Failed to create IP set: %v", err) printNftSets() diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 267e93efd..59a370a97 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -15,8 +15,8 @@ var ( Name: "Insert Forwarding IPV4 Rule", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: false, }, }, @@ -24,8 +24,8 @@ var ( Name: "Insert Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: true, }, }, @@ -40,8 +40,8 @@ var ( Name: "Remove Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: true, }, }, diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 5fe698aa9..ce04c82c7 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -12,7 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -// Reset firewall to the default state +// Close cleans up the firewall manager by removing all rules and closing trackers func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index f63792fec..f261c472f 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -10,7 +10,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -22,7 +21,7 @@ const ( firewallRuleName = "Netbird" ) -// Reset firewall to the default state +// Close cleans up the firewall manager by removing all rules and closing trackers func (m *Manager) Close(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) } if fwder := m.forwarder.Load(); fwder != nil { diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index a23d2011b..b765c72e9 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -29,14 +29,15 @@ func (r *PeerRule) ID() string { } type RouteRule struct { - id string - mgmtId []byte - sources []netip.Prefix - destination netip.Prefix - proto firewall.Protocol - srcPort *firewall.Port - dstPort *firewall.Port - action firewall.Action + id string + mgmtId []byte + sources []netip.Prefix + dstSet firewall.Set + destinations []netip.Prefix + proto firewall.Protocol + srcPort *firewall.Port + dstPort *firewall.Port + action firewall.Action } // ID returns the rule id diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 48b0ec44d..53ee6c886 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -199,7 +199,7 @@ func TestTracePacket(t *testing.T) { src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) - _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { @@ -223,7 +223,7 @@ func TestTracePacket(t *testing.T) { src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) - _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 466c6a18b..ccf0be225 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -49,10 +49,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall") // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule -type RouteRules []RouteRule +type RouteRules []*RouteRule func (r RouteRules) Sort() { - slices.SortStableFunc(r, func(a, b RouteRule) int { + slices.SortStableFunc(r, func(a, b *RouteRule) int { // Deny rules come first if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { return -1 @@ -99,6 +99,8 @@ type Manager struct { forwarder atomic.Pointer[forwarder.Forwarder] logger *nblog.Logger flowLogger nftypes.FlowLogger + + blockRule firewall.Rule } // decoder for packages @@ -201,41 +203,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe } } - if err := m.blockInvalidRouted(iface); err != nil { - log.Errorf("failed to block invalid routed traffic: %v", err) - } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } return m, nil } -func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { - if m.forwarder.Load() == nil { - return nil - } +func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) if err != nil { - return fmt.Errorf("parse wireguard network: %w", err) + return nil, fmt.Errorf("parse wireguard network: %w", err) } log.Debugf("blocking invalid routed traffic for %s", wgPrefix) - if _, err := m.AddRouteFiltering( + rule, err := m.addRouteFiltering( nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, - wgPrefix, + firewall.Network{Prefix: wgPrefix}, firewall.ProtocolALL, nil, nil, firewall.ActionDrop, - ); err != nil { - return fmt.Errorf("block wg nte : %w", err) + ) + if err != nil { + return nil, fmt.Errorf("block wg nte : %w", err) } // TODO: Block networks that we're a client of - return nil + return rule, nil } func (m *Manager) determineRouting() error { @@ -413,10 +409,23 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action) +} + +func (m *Manager) addRouteFiltering( + id []byte, + sources []netip.Prefix, + destination firewall.Network, + proto firewall.Protocol, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { if m.nativeRouter.Load() && m.nativeFirewall != nil { @@ -426,34 +435,39 @@ func (m *Manager) AddRouteFiltering( ruleID := uuid.New().String() rule := RouteRule{ // TODO: consolidate these IDs - id: ruleID, - mgmtId: id, - sources: sources, - destination: destination, - proto: proto, - srcPort: sPort, - dstPort: dPort, - action: action, + id: ruleID, + mgmtId: id, + sources: sources, + dstSet: destination.Set, + proto: proto, + srcPort: sPort, + dstPort: dPort, + action: action, + } + if destination.IsPrefix() { + rule.destinations = []netip.Prefix{destination.Prefix} } - m.mutex.Lock() - m.routeRules = append(m.routeRules, rule) + m.routeRules = append(m.routeRules, &rule) m.routeRules.Sort() - m.mutex.Unlock() return &rule, nil } func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteRouteRule(rule) +} + +func (m *Manager) deleteRouteRule(rule firewall.Rule) error { if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.DeleteRouteRule(rule) } - m.mutex.Lock() - defer m.mutex.Unlock() - ruleID := rule.ID() - idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { + idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool { return r.id == ruleID }) if idx < 0 { @@ -509,6 +523,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.nativeFirewall.DeleteDNATRule(rule) } +// UpdateSet updates the rule destinations associated with the given set +// by merging the existing prefixes with the new ones, then deduplicating. +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + if m.nativeRouter.Load() && m.nativeFirewall != nil { + return m.nativeFirewall.UpdateSet(set, prefixes) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + var matches []*RouteRule + for _, rule := range m.routeRules { + if rule.dstSet == set { + matches = append(matches, rule) + } + } + + if len(matches) == 0 { + return fmt.Errorf("no route rule found for set: %s", set) + } + + destinations := matches[0].destinations + for _, prefix := range prefixes { + if prefix.Addr().Is4() { + destinations = append(destinations, prefix) + } + } + + slices.SortFunc(destinations, func(a, b netip.Prefix) int { + cmp := a.Addr().Compare(b.Addr()) + if cmp != 0 { + return cmp + } + return a.Bits() - b.Bits() + }) + + destinations = slices.Compact(destinations) + + for _, rule := range matches { + rule.destinations = destinations + } + log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations) + + return nil +} + // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte, size int) bool { return m.processOutgoingHooks(packetData, size) @@ -988,8 +1048,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol return nil, false } -func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { - if !rule.destination.Contains(dstAddr) { +func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { + destMatched := false + for _, dst := range rule.destinations { + if dst.Contains(dstAddr) { + destMatched = true + break + } + } + if !destMatched { return false } @@ -1091,7 +1158,22 @@ func (m *Manager) EnableRouting() error { m.mutex.Lock() defer m.mutex.Unlock() - return m.determineRouting() + if err := m.determineRouting(); err != nil { + return fmt.Errorf("determine routing: %w", err) + } + + if m.forwarder.Load() == nil { + return nil + } + + rule, err := m.blockInvalidRouted(m.wgIface) + if err != nil { + return fmt.Errorf("block invalid routed: %w", err) + } + + m.blockRule = rule + + return nil } func (m *Manager) DisableRouting() error { @@ -1116,5 +1198,12 @@ func (m *Manager) DisableRouting() error { log.Debug("forwarder stopped") + if m.blockRule != nil { + if err := m.deleteRouteRule(m.blockRule); err != nil { + return fmt.Errorf("delete block rule: %w", err) + } + m.blockRule = nil + } + return nil } diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 9c0a54e3f..04a398d1f 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/management/domain" ) func TestPeerACLFiltering(t *testing.T) { @@ -600,8 +601,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { } manager, err := Create(ifaceMock, false, flowLogger) - require.NoError(tb, manager.EnableRouting()) require.NoError(tb, err) + require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) require.True(tb, manager.routingEnabled.Load()) require.False(tb, manager.nativeRouter.Load()) @@ -618,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) { type rule struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -644,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -660,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -676,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -692,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 53, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{Values: []uint16{53}}, action: fw.ActionAccept, @@ -706,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolICMP, action: fw.ActionAccept, }, @@ -721,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -737,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -753,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -769,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -785,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{Values: []uint16{12345}}, action: fw.ActionAccept, @@ -804,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) { netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"), }, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -818,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, action: fw.ActionAccept, }, @@ -833,7 +834,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -849,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, action: fw.ActionAccept, @@ -865,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, action: fw.ActionAccept, @@ -881,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, srcPort: &fw.Port{Values: []uint16{12345}}, dstPort: &fw.Port{Values: []uint16{80}}, @@ -898,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -917,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 7999, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -936,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{ IsRange: true, @@ -955,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{ IsRange: true, @@ -977,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8100, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -996,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 5060, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{ IsRange: true, @@ -1015,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{ IsRange: true, @@ -1034,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -1050,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, action: fw.ActionDrop, }, @@ -1068,13 +1069,32 @@ func TestRouteACLFiltering(t *testing.T) { netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"), }, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionDrop, }, shouldPass: false, }, + + { + name: "Drop empty destination set", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: fw.Network{Set: fw.Set{}}, + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, { name: "Accept TCP traffic outside drop port range", srcIP: "100.10.0.1", @@ -1084,7 +1104,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 7999, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, action: fw.ActionDrop, @@ -1100,7 +1120,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, action: fw.ActionAccept, }, @@ -1115,7 +1135,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 53, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, action: fw.ActionAccept, }, @@ -1130,7 +1150,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -1146,7 +1166,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -1160,7 +1180,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, action: fw.ActionAccept, }, @@ -1173,7 +1193,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, action: fw.ActionAccept, }, @@ -1188,7 +1208,7 @@ func TestRouteACLFiltering(t *testing.T) { rule, err := manager.AddRouteFiltering( nil, []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - netip.MustParsePrefix("0.0.0.0/0"), + fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, fw.ProtocolALL, nil, nil, @@ -1235,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) { name string rules []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -1256,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) { name: "Drop rules take precedence over accept", rules: []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -1265,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) { { // Accept rule added first sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80, 443}}, action: fw.ActionAccept, @@ -1273,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) { { // Drop rule added second but should be evaluated first sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -1311,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) { name: "Multiple drop rules take precedence", rules: []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -1320,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) { { // Accept all sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolALL, action: fw.ActionAccept, }, { // Drop specific port sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -1335,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) { { // Drop different port sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionDrop, @@ -1414,3 +1434,53 @@ func TestRouteACLOrder(t *testing.T) { }) } } + +func TestRouteACLSet(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: net.ParseIP("100.10.0.100"), + Network: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + // Add rule that uses the set (initially empty) + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + srcIP := netip.MustParseAddr("100.10.0.1") + dstIP := netip.MustParseAddr("192.168.1.100") + + // Check that traffic is dropped (empty set shouldn't match anything) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + require.False(t, isAllowed, "Empty set should not allow any traffic") + + err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) + require.NoError(t, err) + + // Now the packet should be allowed + _, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index a48a483f8..24a6a2c40 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" + "github.com/netbirdio/netbird/management/domain" ) var logger = log.NewFromLogrus(logrus.StandardLogger()) @@ -711,3 +712,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { }) } } + +func TestUpdateSetMerge(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + initialPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + // Update the set with initial prefixes + err = manager.UpdateSet(set, initialPrefixes) + require.NoError(t, err) + + // Test initial prefixes work + srcIP := netip.MustParseAddr("100.10.0.1") + dstIP1 := netip.MustParseAddr("10.0.0.100") + dstIP2 := netip.MustParseAddr("192.168.1.100") + dstIP3 := netip.MustParseAddr("172.16.0.100") + + _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) + _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80) + + require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") + require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") + require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied") + + newPrefixes := []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("10.1.0.0/24"), + } + + err = manager.UpdateSet(set, newPrefixes) + require.NoError(t, err) + + // Check that all original prefixes are still included + _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) + _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") + require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") + + // Check that new prefixes are included + dstIP4 := netip.MustParseAddr("172.16.1.100") + dstIP5 := netip.MustParseAddr("10.1.0.50") + + _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80) + _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80) + + require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") + require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") + + // Verify the rule has all prefixes + manager.mutex.RLock() + foundRule := false + for _, r := range manager.routeRules { + if r.id == rule.ID() { + foundRule = true + require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes), + "Rule should have all prefixes merged") + } + } + manager.mutex.RUnlock() + require.True(t, foundRule, "Rule should be found") +} + +func TestUpdateSetDeduplication(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + initialPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), // Duplicate + } + + err = manager.UpdateSet(set, initialPrefixes) + require.NoError(t, err) + + // Check the internal state for deduplication + manager.mutex.RLock() + foundRule := false + for _, r := range manager.routeRules { + if r.id == rule.ID() { + foundRule = true + // Should have deduplicated to 2 prefixes + require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed") + + // Check the prefixes are correct + expectedPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + } + for i, prefix := range expectedPrefixes { + require.True(t, r.destinations[i] == prefix, + "Prefix should match expected value") + } + } + } + manager.mutex.RUnlock() + require.True(t, foundRule, "Rule should be found") + + // Test with overlapping prefixes of different sizes + overlappingPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/16"), // More general + netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists) + netip.MustParsePrefix("192.168.0.0/16"), // More general + netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists) + } + + err = manager.UpdateSet(set, overlappingPrefixes) + require.NoError(t, err) + + // Check that all prefixes are included (no deduplication of overlapping prefixes) + manager.mutex.RLock() + for _, r := range manager.routeRules { + if r.id == rule.ID() { + // Should have all 4 prefixes (2 original + 2 new more general ones) + require.Len(t, r.destinations, 4, + "Overlapping prefixes should not be deduplicated") + + // Verify they're sorted correctly (more specific prefixes should come first) + prefixes := make([]string, 0, len(r.destinations)) + for _, p := range r.destinations { + prefixes = append(prefixes, p.String()) + } + + // Check sorted order + require.Equal(t, []string{ + "10.0.0.0/16", + "10.0.0.0/24", + "192.168.0.0/16", + "192.168.1.0/24", + }, prefixes, "Prefixes should be sorted") + } + } + manager.mutex.RUnlock() + + // Test functionality with all prefixes + testCases := []struct { + dstIP netip.Addr + expected bool + desc string + }{ + {netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"}, + {netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"}, + {netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"}, + {netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"}, + {netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"}, + } + + srcIP := netip.MustParseAddr("100.10.0.1") + for _, tc := range testCases { + _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80) + require.Equal(t, tc.expected, isAllowed, tc.desc) + } +} diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index 93f16b429..23451453e 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -18,7 +18,7 @@ func (r RuleID) ID() string { func GenerateRouteRuleKey( sources []netip.Prefix, - destination netip.Prefix, + destination manager.Network, proto manager.Protocol, sPort *manager.Port, dPort *manager.Port, diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 61fbb10ca..6fa35d5c2 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -18,6 +18,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty") // Manager is a ACL rules manager type Manager interface { - ApplyFiltering(networkMap *mgmProto.NetworkMap) + ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } type protoMatch struct { @@ -53,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager { // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // // If allowByDefault is true it appends allow ALL traffic rules to input and output chains. -func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { +func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) { d.mutex.Lock() defer d.mutex.Unlock() @@ -82,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { log.Errorf("failed to set legacy management flag: %v", err) } - if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { + if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { log.Errorf("Failed to apply route ACLs: %v", err) } @@ -176,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { d.peerRulesPairs = newRulePairs } -func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { +func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error { newRouteRules := make(map[id.RuleID]struct{}, len(rules)) var merr *multierror.Error // Apply new rules - firewall manager will return existing rule ID if already present for _, rule := range rules { - id, err := d.applyRouteACL(rule) + id, err := d.applyRouteACL(rule, dynamicResolver) if err != nil { if errors.Is(err, ErrSourceRangesEmpty) { - log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) + log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err) } else { merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) } @@ -208,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err return nberrors.FormatErrorOrNil(merr) } -func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { +func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) { if len(rule.SourceRanges) == 0 { return "", ErrSourceRangesEmpty } @@ -222,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul sources = append(sources, source) } - var destination netip.Prefix - if rule.IsDynamic { - destination = getDefault(sources[0]) - } else { - var err error - destination, err = netip.ParsePrefix(rule.Destination) - if err != nil { - return "", fmt.Errorf("parse destination: %w", err) - } + destination, err := determineDestination(rule, dynamicResolver, sources) + if err != nil { + return "", fmt.Errorf("determine destination: %w", err) } protocol, err := convertToFirewallProtocol(rule.Protocol) @@ -580,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { return nil } +func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) { + var destination firewall.Network + + if rule.IsDynamic { + if dynamicResolver { + if len(rule.Domains) > 0 { + destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains)) + } else { + // isDynamic is set but no domains = outdated management server + log.Warn("connected to an older version of management server (no domains in rules), using default destination") + destination.Prefix = getDefault(sources[0]) + } + } else { + // client resolves DNS, we (router) don't know the destination + destination.Prefix = getDefault(sources[0]) + } + return destination, nil + } + + prefix, err := netip.ParsePrefix(rule.Destination) + if err != nil { + return destination, fmt.Errorf("parse destination: %w", err) + } + destination.Prefix = prefix + return destination, nil +} + func getDefault(prefix netip.Prefix) netip.Prefix { if prefix.Addr().Is6() { return netip.PrefixFrom(netip.IPv6Unspecified(), 0) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 9488d33ab..3595ca600 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -66,7 +66,7 @@ func TestDefaultManager(t *testing.T) { acl := NewDefaultManager(fw) t.Run("apply firewall rules", func(t *testing.T) { - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 2 { t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) @@ -92,7 +92,7 @@ func TestDefaultManager(t *testing.T) { }, ) - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) // we should have one old and one new rule in the existed rules if len(acl.peerRulesPairs) != 2 { @@ -116,13 +116,13 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRulesIsEmpty = true - if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { + if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 { t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) return } networkMap.FirewallRulesIsEmpty = false - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 1 { t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return @@ -359,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }(fw) acl := NewDefaultManager(fw) - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 3 { t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index 291531fea..b4907beca 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -59,6 +59,16 @@ func collectIPTablesRules() (string, error) { builder.WriteString("\n") } + // Collect ipset information + ipsetOutput, err := collectIPSets() + if err != nil { + log.Warnf("Failed to collect ipset information: %v", err) + } else { + builder.WriteString("=== ipset list output ===\n") + builder.WriteString(ipsetOutput) + builder.WriteString("\n") + } + builder.WriteString("=== iptables -v -n -L output ===\n") tables := []string{"filter", "nat", "mangle", "raw", "security"} @@ -78,6 +88,28 @@ func collectIPTablesRules() (string, error) { return builder.String(), nil } +// collectIPSets collects information about ipsets +func collectIPSets() (string, error) { + cmd := exec.Command("ipset", "list") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + if strings.Contains(err.Error(), "executable file not found") { + return "", fmt.Errorf("ipset command not found: %w", err) + } + return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String()) + } + + ipsets := stdout.String() + if strings.TrimSpace(ipsets) == "" { + return "No ipsets found", nil + } + + return ipsets, nil +} + // collectIPTablesSave uses iptables-save to get rule definitions func collectIPTablesSave() (string, error) { cmd := exec.Command("iptables-save") diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 2d69ce858..8f6a31f47 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -3,6 +3,7 @@ package dnsfwd import ( "context" "errors" + "fmt" "math" "net" "net/netip" @@ -10,11 +11,16 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) const errResolveFailed = "failed to resolve query for domain=%s: %v" @@ -23,25 +29,27 @@ const upstreamTimeout = 15 * time.Second type DNSForwarder struct { listenAddress string ttl uint32 - domains []string statusRecorder *peer.Status dnsServer *dns.Server mux *dns.ServeMux - resId sync.Map + mutex sync.RWMutex + fwdEntries []*ForwarderEntry + firewall firewall.Manager } -func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, + firewall: firewall, statusRecorder: statusRecorder, } } -func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error { +func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { log.Infof("listen DNS forwarder on address=%s", f.listenAddress) mux := dns.NewServeMux() @@ -53,31 +61,35 @@ func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error f.dnsServer = dnsServer f.mux = mux - f.UpdateDomains(domains, resIds) + f.UpdateDomains(entries) return dnsServer.ListenAndServe() } -func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) { - log.Debugf("Updating domains from %v to %v", f.domains, domains) +func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { + f.mutex.Lock() + defer f.mutex.Unlock() - for _, d := range f.domains { - f.mux.HandleRemove(d) + if f.mux == nil { + log.Debug("DNS mux is nil, skipping domain update") + f.fwdEntries = entries + return } - f.resId.Clear() - newDomains := filterDomains(domains) + oldDomains := filterDomains(f.fwdEntries) + + for _, d := range oldDomains { + f.mux.HandleRemove(d.PunycodeString()) + } + + newDomains := filterDomains(entries) for _, d := range newDomains { - f.mux.HandleFunc(d, f.handleDNSQuery) + f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery) } - for domain, resId := range resIds { - if domain != "" { - f.resId.Store(domain, resId) - } - } + f.fwdEntries = entries - f.domains = newDomains + log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) } func (f *DNSForwarder) Close(ctx context.Context) error { @@ -91,11 +103,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { if len(query.Question) == 0 { return } - log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", - query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass) - question := query.Question[0] - domain := question.Name + log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", + question.Name, question.Qtype, question.Qclass) + + domain := strings.ToLower(question.Name) resp := query.SetReply(query) var network string @@ -122,21 +134,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { return } - resId := f.getResIdForDomain(strings.TrimSuffix(domain, ".")) - if resId != "" { - for _, ip := range ips { - var ipWithSuffix string - if ip.Is4() { - ipWithSuffix = ip.String() + "/32" - log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix) - } else { - ipWithSuffix = ip.String() + "/128" - log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix) - } - f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId) - } - } - + f.updateInternalState(domain, ips) f.addIPsToResponse(resp, domain, ips) if err := w.WriteMsg(resp); err != nil { @@ -144,6 +142,42 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { } } +func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { + var prefixes []netip.Prefix + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) + if mostSpecificResId != "" { + for _, ip := range ips { + var prefix netip.Prefix + if ip.Is4() { + prefix = netip.PrefixFrom(ip, 32) + } else { + prefix = netip.PrefixFrom(ip, 128) + } + prefixes = append(prefixes, prefix) + f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId) + } + } + + if f.firewall != nil { + f.updateFirewall(matchingEntries, prefixes) + } +} + +func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) { + var merr *multierror.Error + for _, entry := range matchingEntries { + if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil { + merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err)) + } + } + if merr != nil { + log.Errorf("failed to update firewall sets (%d/%d): %v", + len(merr.Errors), + len(matchingEntries), + nberrors.FormatErrorOrNil(merr)) + } +} + // handleDNSError processes DNS lookup errors and sends an appropriate error response func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { var dnsErr *net.DNSError @@ -204,45 +238,53 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti } } -func (f *DNSForwarder) getResIdForDomain(domain string) string { - var selectedResId string +// getMatchingEntries retrieves the resource IDs for a given domain. +// It returns the most specific match and all matching resource IDs. +func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) { + var selectedResId route.ResID var bestScore int + var matches []*ForwarderEntry - f.resId.Range(func(key, value interface{}) bool { + f.mutex.RLock() + defer f.mutex.RUnlock() + + for _, entry := range f.fwdEntries { var score int - pattern := key.(string) + pattern := entry.Domain.PunycodeString() switch { case strings.HasPrefix(pattern, "*."): baseDomain := strings.TrimPrefix(pattern, "*.") - if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) { + + if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) { score = len(baseDomain) + matches = append(matches, entry) } case domain == pattern: score = math.MaxInt + matches = append(matches, entry) default: - return true + continue } if score > bestScore { bestScore = score - selectedResId = value.(string) + selectedResId = entry.ResID } - return true - }) + } - return selectedResId + return selectedResId, matches } // filterDomains returns a list of normalized domains -func filterDomains(domains []string) []string { - newDomains := make([]string, 0, len(domains)) - for _, d := range domains { - if d == "" { +func filterDomains(entries []*ForwarderEntry) domain.List { + newDomains := make(domain.List, 0, len(entries)) + for _, d := range entries { + if d.Domain == "" { log.Warn("empty domain in DNS forwarder") continue } - newDomains = append(newDomains, nbdns.NormalizeZone(d)) + newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString()))) } return newDomains } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 88ffc2af3..f0829bbbd 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -1,56 +1,61 @@ package dnsfwd import ( - "sync" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) -func TestGetResIdForDomain(t *testing.T) { +func Test_getMatchingEntries(t *testing.T) { testCases := []struct { name string - storedMappings map[string]string // key: domain pattern, value: resId + storedMappings map[string]route.ResID // key: domain pattern, value: resId queryDomain string - expectedResId string + expectedResId route.ResID }{ { name: "Empty map returns empty string", - storedMappings: map[string]string{}, + storedMappings: map[string]route.ResID{}, queryDomain: "example.com", expectedResId: "", }, { name: "Exact match returns stored resId", - storedMappings: map[string]string{"example.com": "res1"}, + storedMappings: map[string]route.ResID{"example.com": "res1"}, queryDomain: "example.com", expectedResId: "res1", }, { name: "Wildcard pattern matches base domain", - storedMappings: map[string]string{"*.example.com": "res2"}, + storedMappings: map[string]route.ResID{"*.example.com": "res2"}, queryDomain: "example.com", expectedResId: "res2", }, { name: "Wildcard pattern matches subdomain", - storedMappings: map[string]string{"*.example.com": "res3"}, + storedMappings: map[string]route.ResID{"*.example.com": "res3"}, queryDomain: "foo.example.com", expectedResId: "res3", }, { name: "Wildcard pattern does not match different domain", - storedMappings: map[string]string{"*.example.com": "res4"}, + storedMappings: map[string]route.ResID{"*.example.com": "res4"}, queryDomain: "foo.notexample.com", expectedResId: "", }, { name: "Non-wildcard pattern does not match subdomain", - storedMappings: map[string]string{"example.com": "res5"}, + storedMappings: map[string]route.ResID{"example.com": "res5"}, queryDomain: "foo.example.com", expectedResId: "", }, { name: "Exact match over overlapping wildcard", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resWildcard", "foo.example.com": "resExact", }, @@ -59,7 +64,7 @@ func TestGetResIdForDomain(t *testing.T) { }, { name: "Overlapping wildcards: Select more specific wildcard", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resA", "*.sub.example.com": "resB", }, @@ -68,7 +73,7 @@ func TestGetResIdForDomain(t *testing.T) { }, { name: "Wildcard multi-level subdomain match", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resMulti", }, queryDomain: "a.b.example.com", @@ -78,18 +83,21 @@ func TestGetResIdForDomain(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - fwd := &DNSForwarder{ - resId: sync.Map{}, - } + fwd := &DNSForwarder{} + var entries []*ForwarderEntry for domainPattern, resId := range tc.storedMappings { - fwd.resId.Store(domainPattern, resId) + d, err := domain.FromString(domainPattern) + require.NoError(t, err) + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: resId, + }) } + fwd.UpdateDomains(entries) - got := fwd.getResIdForDomain(tc.queryDomain) - if got != tc.expectedResId { - t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got) - } + got, _ := fwd.getMatchingEntries(tc.queryDomain) + assert.Equal(t, got, tc.expectedResId) }) } } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a51ae7abb..e4a23450f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -11,6 +11,8 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) const ( @@ -19,6 +21,13 @@ const ( dnsTTL = 60 //seconds ) +// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. +type ForwarderEntry struct { + Domain domain.Domain + ResID route.ResID + Set firewall.Set +} + type Manager struct { firewall firewall.Manager statusRecorder *peer.Status @@ -34,7 +43,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { } } -func (m *Manager) Start(domains []string, resIds map[string]string) error { +func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { log.Infof("starting DNS forwarder") if m.dnsForwarder != nil { return nil @@ -44,9 +53,9 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder) + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) go func() { - if err := m.dnsForwarder.Listen(domains, resIds); err != nil { + if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists log.Errorf("failed to start DNS forwarder, err: %v", err) } @@ -55,12 +64,12 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error { return nil } -func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) { +func (m *Manager) UpdateDomains(entries []*ForwarderEntry) { if m.dnsForwarder == nil { return } - m.dnsForwarder.UpdateDomains(domains, resIds) + m.dnsForwarder.UpdateDomains(entries) } func (m *Manager) Stop(ctx context.Context) error { @@ -81,34 +90,34 @@ func (m *Manager) Stop(ctx context.Context) error { return nberrors.FormatErrorOrNil(mErr) } -func (h *Manager) allowDNSFirewall() error { +func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, Values: []uint16{ListenPort}, } - if h.firewall == nil { + if m.firewall == nil { return nil } - dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") + dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err } - h.fwRules = dnsRules + m.fwRules = dnsRules return nil } -func (h *Manager) dropDNSFirewall() error { +func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error - for _, rule := range h.fwRules { - if err := h.firewall.DeletePeerRule(rule); err != nil { + for _, rule := range m.fwRules { + if err := m.firewall.DeletePeerRule(rule); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) } } - h.fwRules = nil + m.fwRules = nil return nberrors.FormatErrorOrNil(mErr) } diff --git a/client/internal/engine.go b/client/internal/engine.go index c377c12e1..b16232883 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -527,7 +527,7 @@ func (e *Engine) blockLanAccess() { if _, err := e.firewall.AddRouteFiltering( nil, []netip.Prefix{v4}, - network, + firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, nil, @@ -960,21 +960,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } - // DNS forwarder dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) - dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) - e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - // acls might need routing to be enabled, so we apply after routes if e.acl != nil { - e.acl.ApplyFiltering(networkMap) + e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag) } + fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) + // Ingress forward rules if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil { log.Errorf("failed to update forward rules, err: %v", err) @@ -1079,29 +1079,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { return routes } -func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) { - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} - } - - var dnsRoutes []string - resIds := make(map[string]string) - for _, protoRoute := range protoRoutes { - if len(protoRoute.Domains) == 0 { +func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderEntry { + var entries []*dnsfwd.ForwarderEntry + for _, route := range routes { + if len(route.Domains) == 0 { continue } - if protoRoute.Peer == myPubKey { - dnsRoutes = append(dnsRoutes, protoRoute.Domains...) - // resource ID is the first part of the ID - resId := strings.Split(protoRoute.ID, ":") - for _, domain := range protoRoute.Domains { - if len(resId) > 0 { - resIds[domain] = resId[0] - } + if route.Peer == myPubKey { + domainSet := firewallManager.NewDomainSet(route.Domains) + for _, d := range route.Domains { + entries = append(entries, &dnsfwd.ForwarderEntry{ + Domain: d, + Set: domainSet, + ResID: route.GetResourceID(), + }) } } } - return dnsRoutes, resIds + return entries } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { @@ -1751,7 +1746,10 @@ func (e *Engine) GetWgAddr() net.IP { } // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag -func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) { +func (e *Engine) updateDNSForwarder( + enabled bool, + fwdEntries []*dnsfwd.ForwarderEntry, +) { if !enabled { if e.dnsForwardMgr == nil { return @@ -1762,18 +1760,18 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[s return } - if len(domains) > 0 { - log.Infof("enable domain router service for domains: %v", domains) + if len(fwdEntries) > 0 { if e.dnsForwardMgr == nil { e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(domains, resIds); err != nil { + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) } else { - log.Infof("update domain router service for domains: %v", domains) - e.dnsForwardMgr.UpdateDomains(domains, resIds) + e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") diff --git a/client/internal/peer/route.go b/client/internal/peer/route.go index c3567dcc9..e5e315e3c 100644 --- a/client/internal/peer/route.go +++ b/client/internal/peer/route.go @@ -6,12 +6,14 @@ import ( "sync" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/route" ) // routeEntry holds the route prefix and the corresponding resource ID. type routeEntry struct { prefix netip.Prefix - resourceID string + resourceID route.ResID } type routeIDLookup struct { @@ -24,7 +26,7 @@ type routeIDLookup struct { resolvedIPs sync.Map } -func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddLocalRouteID(resourceID route.ResID, route netip.Prefix) { r.localLock.Lock() defer r.localLock.Unlock() @@ -56,7 +58,7 @@ func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) { } } -func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddRemoteRouteID(resourceID route.ResID, route netip.Prefix) { r.remoteLock.Lock() defer r.remoteLock.Unlock() @@ -87,7 +89,7 @@ func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) { } } -func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddResolvedIP(resourceID route.ResID, route netip.Prefix) { r.resolvedIPs.Store(route.Addr(), resourceID) } @@ -97,19 +99,19 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) { // Lookup returns the resource ID for the given IP address // and a bool indicating if the IP is an exit node. -func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { +func (r *routeIDLookup) Lookup(ip netip.Addr) (route.ResID, bool) { if res, ok := r.resolvedIPs.Load(ip); ok { - return res.(string), false + return res.(route.ResID), false } - var resourceID string + var resourceID route.ResID var isExitNode bool r.localLock.RLock() for _, entry := range r.localRoutes { if entry.prefix.Contains(ip) { resourceID = entry.resourceID - isExitNode = (entry.prefix.Bits() == 0) + isExitNode = entry.prefix.Bits() == 0 break } } @@ -120,7 +122,7 @@ func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { for _, entry := range r.remoteRoutes { if entry.prefix.Contains(ip) { resourceID = entry.resourceID - isExitNode = (entry.prefix.Bits() == 0) + isExitNode = entry.prefix.Bits() == 0 break } } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 9b3fc744d..3eca6a8c9 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/domain" relayClient "github.com/netbirdio/netbird/relay/client" + "github.com/netbirdio/netbird/route" ) const eventQueueSize = 10 @@ -313,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } -func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error { +func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { d.mux.Lock() defer d.mux.Unlock() @@ -581,7 +582,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { } // AddLocalPeerStateRoute adds a route to the local peer state -func (d *Status) AddLocalPeerStateRoute(route, resourceId string) { +func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() @@ -611,14 +612,11 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) { } // AddResolvedIPLookupEntry adds a resolved IP lookup entry -func (d *Status) AddResolvedIPLookupEntry(route, resourceId string) { +func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() - pref, err := netip.ParsePrefix(route) - if err == nil { - d.routeIDLookup.AddResolvedIP(resourceId, pref) - } + d.routeIDLookup.AddResolvedIP(resourceId, prefix) } // RemoveResolvedIPLookupEntry removes a resolved IP lookup entry @@ -723,7 +721,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } -func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) { +func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 68d81d968..6d51c88c0 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -234,7 +234,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { origPattern = writer.GetOrigPattern() } - resolvedDomain := domain.Domain(r.Question[0].Name) + resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name)) // already punycode via RegisterHandler() originalDomain := domain.Domain(origPattern) @@ -328,6 +328,11 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom // Update domain prefixes using resolved domain as key if len(toAdd) > 0 || len(toRemove) > 0 { + if d.route.KeepRoute { + // replace stored prefixes with old + added + // nolint:gocritic + newPrefixes = append(oldPrefixes, toAdd...) + } d.interceptedDomains[resolvedDomain] = newPrefixes originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) @@ -338,7 +343,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom originalDomain.SafeString(), toAdd) } - if len(toRemove) > 0 { + if len(toRemove) > 0 && !d.route.KeepRoute { log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", resolvedDomain.SafeString(), originalDomain.SafeString(), diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index ae0d1d220..078206ab9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -259,8 +259,6 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } } - m.ctx = nil - m.mux.Lock() defer m.mux.Unlock() m.clientRoutes = nil @@ -292,7 +290,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro return nil } - if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil { + if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { return fmt.Errorf("update routes: %w", err) } diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 48bb0380d..953210e9e 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -18,7 +18,7 @@ type serverRouter struct { func (r serverRouter) cleanUp() { } -func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { +func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error { return nil } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 18713ee65..131d4c170 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -35,7 +35,10 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi }, nil } -func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { +func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error { + m.mux.Lock() + defer m.mux.Unlock() + serverRoutesToRemove := make([]route.ID, 0) for routeID := range m.routes { @@ -73,7 +76,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { continue } - err := m.addToServerNetwork(newRoute) + err := m.addToServerNetwork(newRoute, useNewDNSRoute) if err != nil { log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue @@ -90,57 +93,30 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { return m.ctx.Err() } - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { + routerPair := routeToRouterPair(route, false) + if err := m.firewall.RemoveNatRule(routerPair); err != nil { return fmt.Errorf("remove routing rules: %w", err) } delete(m.routes, route.ID) - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - m.statusRecorder.RemoveLocalPeerStateRoute(routeStr) + m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString()) return nil } -func (m *serverRouter) addToServerNetwork(route *route.Route) error { +func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error { if m.ctx.Err() != nil { log.Infof("Not adding to server network because context is done") return m.ctx.Err() } - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.AddNatRule(routerPair) - if err != nil { + routerPair := routeToRouterPair(route, useNewDNSRoute) + if err := m.firewall.AddNatRule(routerPair); err != nil { return fmt.Errorf("insert routing rules: %w", err) } m.routes[route.ID] = route - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - - m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID()) + m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID()) return nil } @@ -148,31 +124,29 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error { func (m *serverRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() - for _, r := range m.routes { - routerPair, err := routeToRouterPair(r) - if err != nil { - log.Errorf("Failed to convert route to router pair: %v", err) - continue - } - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { + for _, r := range m.routes { + routerPair := routeToRouterPair(r, false) + if err := m.firewall.RemoveNatRule(routerPair); err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } - } m.statusRecorder.CleanLocalPeerStateRoutes() } -func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { - // TODO: add ipv6 +func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair { source := getDefaultPrefix(route.Network) - - destination := route.Network.Masked() + destination := firewall.Network{} if route.IsDynamic() { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination) + if useNewDNSRoute { + destination.Set = firewall.NewDomainSet(route.Domains) + } else { + // TODO: add ipv6 additionally + destination = getDefaultPrefix(destination.Prefix) + } + } else { + destination.Prefix = route.Network.Masked() } return firewall.RouterPair{ @@ -180,12 +154,16 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { Source: source, Destination: destination, Masquerade: route.Masquerade, - }, nil + } } -func getDefaultPrefix(prefix netip.Prefix) netip.Prefix { +func getDefaultPrefix(prefix netip.Prefix) firewall.Network { if prefix.Addr().Is6() { - return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + return firewall.Network{ + Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + } + } + return firewall.Network{ + Prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0), } - return netip.PrefixFrom(netip.IPv4Unspecified(), 0) } diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index cf3c2f0aa..59b6346c6 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -45,7 +45,7 @@ var sysctlFailed bool type ruleParams struct { priority int - fwmark int + fwmark uint32 tableID int family int invert bool @@ -55,8 +55,8 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, - {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, + {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, + {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, } diff --git a/client/server/network.go b/client/server/network.go index e0b01f763..93b7caa46 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro // Convert to proto format for domain, ips := range domainMap { - pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{ + pbRoute.ResolvedIPs[domain.SafeString()] = &proto.IPList{ Ips: ips, } } diff --git a/client/status/status.go b/client/status/status.go index 43acc9197..f37e5b0f0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/version" ) @@ -414,7 +415,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, signalConnString, relaysString, dnsServersString, - overview.FQDN, + domain.Domain(overview.FQDN).SafeString(), interfaceIP, interfaceTypeString, rosenpassEnabledStatus, @@ -508,7 +509,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Quantum resistance: %s\n"+ " Networks: %s\n"+ " Latency: %s\n", - peerState.FQDN, + domain.Domain(peerState.FQDN).SafeString(), peerState.IP, peerState.PubKey, peerState.Status, diff --git a/dns/dns.go b/dns/dns.go index 8dfdf8526..3a1c76e56 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -111,6 +111,5 @@ func GetParsedDomainLabel(name string) (string, error) { // NormalizeZone returns a normalized domain name without the wildcard prefix func NormalizeZone(domain string) string { - d, _ := strings.CutPrefix(domain, "*.") - return d + return strings.TrimPrefix(domain, "*.") } diff --git a/go.mod b/go.mod index b1b01d446..095840f13 100644 --- a/go.mod +++ b/go.mod @@ -18,9 +18,9 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 - github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.36.0 - golang.org/x/sys v0.31.0 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/crypto v0.37.0 + golang.org/x/sys v0.32.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -39,7 +39,6 @@ require ( github.com/coder/websocket v1.8.12 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 - github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2 @@ -49,7 +48,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.2.0 + github.com/google/nftables v0.3.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -100,10 +99,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.38.0 + golang.org/x/net v0.39.0 golang.org/x/oauth2 v0.24.0 - golang.org/x/sync v0.12.0 - golang.org/x/term v0.30.0 + golang.org/x/sync v0.13.0 + golang.org/x/term v0.31.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -145,6 +144,7 @@ require ( github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v26.1.5+incompatible // indirect @@ -183,7 +183,6 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/josharian/native v1.1.0 // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect @@ -192,7 +191,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect - github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -235,7 +234,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index fb351dd25..8c1c021f8 100644 --- a/go.sum +++ b/go.sum @@ -301,8 +301,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= -github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= +github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= +github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -399,8 +399,6 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= -github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -447,8 +445,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= @@ -665,9 +663,8 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= -github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= @@ -752,8 +749,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -846,8 +843,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -876,8 +873,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -902,7 +899,6 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -911,7 +907,6 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -939,14 +934,16 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -954,8 +951,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -969,8 +966,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/management/domain/domain.go b/management/domain/domain.go index 2e089b01f..97acec688 100644 --- a/management/domain/domain.go +++ b/management/domain/domain.go @@ -1,12 +1,17 @@ package domain import ( + "strings" + "golang.org/x/net/idna" ) +// Domain represents a punycode-encoded domain string. +// This should only be converted from a string when the string already is in punycode, otherwise use FromString. type Domain string // String converts the Domain to a non-punycode string. +// For an infallible conversion, use SafeString. func (d Domain) String() (string, error) { unicode, err := idna.ToUnicode(string(d)) if err != nil { @@ -15,16 +20,17 @@ func (d Domain) String() (string, error) { return unicode, nil } -// SafeString converts the Domain to a non-punycode string, falling back to the original string if conversion fails. +// SafeString converts the Domain to a non-punycode string, falling back to the punycode string if conversion fails. func (d Domain) SafeString() string { str, err := d.String() if err != nil { - str = string(d) + return string(d) } return str } // PunycodeString returns the punycode representation of the Domain. +// This should only be used if a punycode domain is expected but only a string is supported. func (d Domain) PunycodeString() string { return string(d) } @@ -35,5 +41,5 @@ func FromString(s string) (Domain, error) { if err != nil { return "", err } - return Domain(ascii), nil + return Domain(strings.ToLower(ascii)), nil } diff --git a/management/domain/list.go b/management/domain/list.go index b6090c717..a988f4f70 100644 --- a/management/domain/list.go +++ b/management/domain/list.go @@ -5,6 +5,7 @@ import ( "strings" ) +// List is a slice of punycode-encoded domain strings. type List []Domain // ToStringList converts a List to a slice of string. @@ -53,7 +54,7 @@ func (d List) String() (string, error) { func (d List) SafeString() string { str, err := d.String() if err != nil { - return strings.Join(d.ToPunycodeList(), ", ") + return d.PunycodeString() } return str } @@ -101,7 +102,7 @@ func FromStringList(s []string) (List, error) { func FromPunycodeList(s []string) List { var dl List for _, domain := range s { - dl = append(dl, Domain(domain)) + dl = append(dl, Domain(strings.ToLower(domain))) } return dl } diff --git a/management/domain/validate.go b/management/domain/validate.go index bcbf26e05..a42aebe6f 100644 --- a/management/domain/validate.go +++ b/management/domain/validate.go @@ -22,8 +22,6 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - d := strings.ToLower(d) - // handles length and idna conversion punycode, err := FromString(d) if err != nil { diff --git a/management/server/types/account.go b/management/server/types/account.go index e9fa37085..8315f5796 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1289,7 +1289,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer if route.Peer != peer.Key { continue } - resourceAppliedPolicies := resourcePolicies[route.GetResourceID()] + resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())] distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) diff --git a/route/hauniqueid.go b/route/hauniqueid.go index 4d952beba..064608171 100644 --- a/route/hauniqueid.go +++ b/route/hauniqueid.go @@ -4,13 +4,14 @@ import "strings" const haSeparator = "|" +// HAUniqueID is a unique identifier that is used to group high availability routes. type HAUniqueID string func (id HAUniqueID) String() string { return string(id) } -// NetID returns the Network ID from the HAUniqueID +// NetID returns the NetID from the HAUniqueID func (id HAUniqueID) NetID() NetID { if i := strings.LastIndex(string(id), haSeparator); i != -1 { return NetID(id[:i]) diff --git a/route/route.go b/route/route.go index f7bf3ea87..722dacc2d 100644 --- a/route/route.go +++ b/route/route.go @@ -6,8 +6,6 @@ import ( "slices" "strings" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/status" ) @@ -46,10 +44,16 @@ const ( DomainNetwork ) +// ID is the unique route ID. type ID string +// ResID is the resourceID part of a route.ID (first part before the colon). +type ResID string + +// NetID is the route network identifier, a human-readable string. type NetID string +// HAMap is a map of HAUniqueID to a list of routes. type HAMap map[HAUniqueID][]*Route // NetworkType route network type @@ -162,21 +166,25 @@ func (r *Route) IsDynamic() bool { return r.NetworkType == DomainNetwork } +// GetHAUniqueID returns the HAUniqueID for the route, it can be used for grouping. func (r *Route) GetHAUniqueID() HAUniqueID { - if r.IsDynamic() { - domains, err := r.Domains.String() - if err != nil { - log.Errorf("Failed to convert domains to string: %v", err) - domains = r.Domains.PunycodeString() - } - return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, domains)) - } - return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String())) + return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.NetString())) } -// GetResourceID returns the Networks Resource ID from a route ID -func (r *Route) GetResourceID() string { - return strings.Split(string(r.ID), ":")[0] +// GetResourceID returns the Networks ResID from the route ID. +// It's the part before the first colon in the ID string. +func (r *Route) GetResourceID() ResID { + return ResID(strings.Split(string(r.ID), ":")[0]) +} + +// NetString returns the network string. +// If the route is dynamic, it returns the domains as comma-separated punycode-encoded string. +// If the route is not dynamic, it returns the network (prefix) string. +func (r *Route) NetString() string { + if r.IsDynamic() { + return r.Domains.SafeString() + } + return r.Network.String() } // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6