[management merge only unique entries on network map merge (#3277)

This commit is contained in:
Pascal Fischer 2025-02-05 16:50:45 +01:00 committed by GitHub
parent b2a5b29fb2
commit 035c5d9f23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 216 additions and 19 deletions

View File

@ -4,7 +4,7 @@ import (
"context"
"fmt"
"reflect"
runtime "runtime"
"runtime"
"time"
"github.com/hashicorp/go-multierror"
@ -290,7 +290,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
// If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
c.currentChosen.IsEqual(c.routes[newChosenID]) {
c.currentChosen.Equal(c.routes[newChosenID]) {
return nil
}

View File

@ -40,7 +40,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) {
if !found || !update.Equal(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
}
}

View File

@ -1,6 +1,9 @@
package domain
import "strings"
import (
"sort"
"strings"
)
type List []Domain
@ -60,6 +63,27 @@ func (d List) PunycodeString() string {
return strings.Join(d.ToPunycodeList(), ", ")
}
func (d List) Equal(domains List) bool {
if len(d) != len(domains) {
return false
}
sort.Slice(d, func(i, j int) bool {
return d[i] < d[j]
})
sort.Slice(domains, func(i, j int) bool {
return domains[i] < domains[j]
})
for i, domain := range d {
if domain != domains[i] {
return false
}
}
return true
}
// FromStringList creates a DomainList from a slice of string.
func FromStringList(s []string) (List, error) {
var dl List

View File

@ -0,0 +1,49 @@
package domain
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_EqualReturnsTrueForIdenticalLists(t *testing.T) {
list1 := List{"domain1", "domain2", "domain3"}
list2 := List{"domain1", "domain2", "domain3"}
assert.True(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForDifferentLengths(t *testing.T) {
list1 := List{"domain1", "domain2"}
list2 := List{"domain1", "domain2", "domain3"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForDifferentElements(t *testing.T) {
list1 := List{"domain1", "domain2", "domain3"}
list2 := List{"domain1", "domain4", "domain3"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsTrueForUnsortedIdenticalLists(t *testing.T) {
list1 := List{"domain3", "domain1", "domain2"}
list2 := List{"domain1", "domain2", "domain3"}
assert.True(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForEmptyAndNonEmptyList(t *testing.T) {
list1 := List{}
list2 := List{"domain1"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsTrueForBothEmptyLists(t *testing.T) {
list1 := List{}
list2 := List{}
assert.True(t, list1.Equal(list2))
}

View File

@ -284,7 +284,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
for _, rule := range firewallRules {
contains := false
for _, expectedRule := range epectedFirewallRules {
if rule.IsEqual(expectedRule) {
if rule.Equal(expectedRule) {
contains = true
break
}

View File

@ -459,7 +459,7 @@ func TestCreateRoute(t *testing.T) {
// assign generated ID
testCase.expectedRoute.ID = outRoute.ID
if !testCase.expectedRoute.IsEqual(outRoute) {
if !testCase.expectedRoute.Equal(outRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute)
}
})
@ -1000,7 +1000,7 @@ func TestSaveRoute(t *testing.T) {
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
require.True(t, saved)
if !testCase.expectedRoute.IsEqual(savedRoute) {
if !testCase.expectedRoute.Equal(savedRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute)
}
})
@ -1194,7 +1194,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err)
require.Len(t, peer1Routes.Routes, 1, "we should receive one route for peer1")
require.True(t, expectedRoute.IsEqual(peer1Routes.Routes[0]), "received route should be equal")
require.True(t, expectedRoute.Equal(peer1Routes.Routes[0]), "received route should be equal")
peer2Routes, err := am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err)
@ -1206,7 +1206,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err)
require.Len(t, peer2Routes.Routes, 1, "we should receive one route")
require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group")
require.True(t, peer1Routes.Routes[0].Equal(peer2Routes.Routes[0]), "routes should be the same for peers in the same group")
newGroup := &types.Group{
ID: xid.New().String(),

View File

@ -38,8 +38,8 @@ type FirewallRule struct {
PortRange RulePortRange
}
// IsEqual checks if two firewall rules are equal.
func (r *FirewallRule) IsEqual(other *FirewallRule) bool {
// Equal checks if two firewall rules are equal.
func (r *FirewallRule) Equal(other *FirewallRule) bool {
return r.PeerIP == other.PeerIP &&
r.Direction == other.Direction &&
r.Action == other.Action &&

View File

@ -8,11 +8,13 @@ import (
"github.com/c-robinson/iplib"
"github.com/rs/xid"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
)
@ -38,12 +40,26 @@ type NetworkMap struct {
}
func (nm *NetworkMap) Merge(other *NetworkMap) {
nm.Peers = append(nm.Peers, other.Peers...)
nm.Routes = append(nm.Routes, other.Routes...)
nm.OfflinePeers = append(nm.OfflinePeers, other.OfflinePeers...)
nm.FirewallRules = append(nm.FirewallRules, other.FirewallRules...)
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, other.RoutesFirewallRules...)
nm.ForwardingRules = append(nm.ForwardingRules, other.ForwardingRules...)
nm.Peers = mergeUniquePeersByID(nm.Peers, other.Peers)
nm.Routes = util.MergeUnique(nm.Routes, other.Routes)
nm.OfflinePeers = mergeUniquePeersByID(nm.OfflinePeers, other.OfflinePeers)
nm.FirewallRules = util.MergeUnique(nm.FirewallRules, other.FirewallRules)
nm.RoutesFirewallRules = util.MergeUnique(nm.RoutesFirewallRules, other.RoutesFirewallRules)
nm.ForwardingRules = util.MergeUnique(nm.ForwardingRules, other.ForwardingRules)
}
func mergeUniquePeersByID(peers1, peers2 []*nbpeer.Peer) []*nbpeer.Peer {
result := make(map[string]*nbpeer.Peer)
for _, peer := range peers1 {
result[peer.ID] = peer
}
for _, peer := range peers2 {
if _, ok := result[peer.ID]; !ok {
result[peer.ID] = peer
}
}
return maps.Values(result)
}
type ForwardingRule struct {
@ -75,6 +91,13 @@ func (f *ForwardingRule) ToProto() *proto.ForwardingRule {
}
}
func (f *ForwardingRule) Equal(other *ForwardingRule) bool {
return f.RuleProtocol == other.RuleProtocol &&
f.DestinationPorts.Equal(&other.DestinationPorts) &&
f.TranslatedAddress.Equal(other.TranslatedAddress) &&
f.TranslatedPorts.Equal(&other.TranslatedPorts)
}
func ipToBytes(ip net.IP) []byte {
if ip4 := ip.To4(); ip4 != nil {
return ip4

View File

@ -33,6 +33,10 @@ func (r *RulePortRange) ToProto() *proto.PortInfo {
}
}
func (r *RulePortRange) Equal(other *RulePortRange) bool {
return r.Start == other.Start && r.End == other.End
}
// PolicyRule is the metadata of the policy
type PolicyRule struct {
// ID of the policy rule

View File

@ -30,3 +30,28 @@ type RouteFirewallRule struct {
// isDynamic indicates whether the rule is for DNS routing
IsDynamic bool
}
func (r *RouteFirewallRule) Equal(other *RouteFirewallRule) bool {
if r.Action != other.Action {
return false
}
if r.Destination != other.Destination {
return false
}
if r.Protocol != other.Protocol {
return false
}
if r.Port != other.Port {
return false
}
if !r.PortRange.Equal(&other.PortRange) {
return false
}
if !r.Domains.Equal(other.Domains) {
return false
}
if r.IsDynamic != other.IsDynamic {
return false
}
return true
}

View File

@ -19,3 +19,34 @@ func Difference(a, b []string) []string {
func ToPtr[T any](value T) *T {
return &value
}
type comparableObject[T any] interface {
Equal(other T) bool
}
func MergeUnique[T comparableObject[T]](arr1, arr2 []T) []T {
var result []T
for _, item := range arr1 {
if !contains(result, item) {
result = append(result, item)
}
}
for _, item := range arr2 {
if !contains(result, item) {
result = append(result, item)
}
}
return result
}
func contains[T comparableObject[T]](slice []T, element T) bool {
for _, item := range slice {
if item.Equal(element) {
return true
}
}
return false
}

View File

@ -0,0 +1,41 @@
package util
import (
"testing"
"github.com/stretchr/testify/assert"
)
type testObject struct {
value int
}
func (t testObject) Equal(other testObject) bool {
return t.value == other.value
}
func Test_MergeUniqueArraysWithoutDuplicates(t *testing.T) {
arr1 := []testObject{{value: 1}, {value: 2}}
arr2 := []testObject{{value: 2}, {value: 3}}
result := MergeUnique(arr1, arr2)
assert.Len(t, result, 3)
assert.Contains(t, result, testObject{value: 1})
assert.Contains(t, result, testObject{value: 2})
assert.Contains(t, result, testObject{value: 3})
}
func Test_MergeUniqueHandlesEmptyArrays(t *testing.T) {
arr1 := []testObject{}
arr2 := []testObject{}
result := MergeUnique(arr1, arr2)
assert.Empty(t, result)
}
func Test_MergeUniqueHandlesOneEmptyArray(t *testing.T) {
arr1 := []testObject{{value: 1}, {value: 2}}
arr2 := []testObject{}
result := MergeUnique(arr1, arr2)
assert.Len(t, result, 2)
assert.Contains(t, result, testObject{value: 1})
assert.Contains(t, result, testObject{value: 2})
}

View File

@ -132,8 +132,8 @@ func (r *Route) Copy() *Route {
return route
}
// IsEqual compares one route with the other
func (r *Route) IsEqual(other *Route) bool {
// Equal compares one route with the other
func (r *Route) Equal(other *Route) bool {
if r == nil && other == nil {
return true
} else if r == nil || other == nil {