diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index faf0fadaa..c755194f0 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "reflect" - runtime "runtime" + "runtime" "time" "github.com/hashicorp/go-multierror" @@ -290,7 +290,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { // If the chosen route is the same as the current route, do nothing if c.currentChosen != nil && c.currentChosen.ID == newChosenID && - c.currentChosen.IsEqual(c.routes[newChosenID]) { + c.currentChosen.Equal(c.routes[newChosenID]) { return nil } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 6ff80e52d..68218c0e2 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -40,7 +40,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { for routeID := range m.routes { update, found := routesMap[routeID] - if !found || !update.IsEqual(m.routes[routeID]) { + if !found || !update.Equal(m.routes[routeID]) { serverRoutesToRemove = append(serverRoutesToRemove, routeID) } } diff --git a/management/domain/list.go b/management/domain/list.go index 413a23442..b6090c717 100644 --- a/management/domain/list.go +++ b/management/domain/list.go @@ -1,6 +1,9 @@ package domain -import "strings" +import ( + "sort" + "strings" +) type List []Domain @@ -60,6 +63,27 @@ func (d List) PunycodeString() string { return strings.Join(d.ToPunycodeList(), ", ") } +func (d List) Equal(domains List) bool { + if len(d) != len(domains) { + return false + } + + sort.Slice(d, func(i, j int) bool { + return d[i] < d[j] + }) + + sort.Slice(domains, func(i, j int) bool { + return domains[i] < domains[j] + }) + + for i, domain := range d { + if domain != domains[i] { + return false + } + } + return true +} + // FromStringList creates a DomainList from a slice of string. func FromStringList(s []string) (List, error) { var dl List diff --git a/management/domain/list_test.go b/management/domain/list_test.go new file mode 100644 index 000000000..5000af01c --- /dev/null +++ b/management/domain/list_test.go @@ -0,0 +1,49 @@ +package domain + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_EqualReturnsTrueForIdenticalLists(t *testing.T) { + list1 := List{"domain1", "domain2", "domain3"} + list2 := List{"domain1", "domain2", "domain3"} + + assert.True(t, list1.Equal(list2)) +} + +func Test_EqualReturnsFalseForDifferentLengths(t *testing.T) { + list1 := List{"domain1", "domain2"} + list2 := List{"domain1", "domain2", "domain3"} + + assert.False(t, list1.Equal(list2)) +} + +func Test_EqualReturnsFalseForDifferentElements(t *testing.T) { + list1 := List{"domain1", "domain2", "domain3"} + list2 := List{"domain1", "domain4", "domain3"} + + assert.False(t, list1.Equal(list2)) +} + +func Test_EqualReturnsTrueForUnsortedIdenticalLists(t *testing.T) { + list1 := List{"domain3", "domain1", "domain2"} + list2 := List{"domain1", "domain2", "domain3"} + + assert.True(t, list1.Equal(list2)) +} + +func Test_EqualReturnsFalseForEmptyAndNonEmptyList(t *testing.T) { + list1 := List{} + list2 := List{"domain1"} + + assert.False(t, list1.Equal(list2)) +} + +func Test_EqualReturnsTrueForBothEmptyLists(t *testing.T) { + list1 := List{} + list2 := List{} + + assert.True(t, list1.Equal(list2)) +} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 73fc6edba..90f9670d1 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -284,7 +284,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { for _, rule := range firewallRules { contains := false for _, expectedRule := range epectedFirewallRules { - if rule.IsEqual(expectedRule) { + if rule.Equal(expectedRule) { contains = true break } diff --git a/management/server/route_test.go b/management/server/route_test.go index 7169316d4..e4585753f 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -459,7 +459,7 @@ func TestCreateRoute(t *testing.T) { // assign generated ID testCase.expectedRoute.ID = outRoute.ID - if !testCase.expectedRoute.IsEqual(outRoute) { + if !testCase.expectedRoute.Equal(outRoute) { t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute) } }) @@ -1000,7 +1000,7 @@ func TestSaveRoute(t *testing.T) { savedRoute, saved := account.Routes[testCase.expectedRoute.ID] require.True(t, saved) - if !testCase.expectedRoute.IsEqual(savedRoute) { + if !testCase.expectedRoute.Equal(savedRoute) { t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute) } }) @@ -1194,7 +1194,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, peer1Routes.Routes, 1, "we should receive one route for peer1") - require.True(t, expectedRoute.IsEqual(peer1Routes.Routes[0]), "received route should be equal") + require.True(t, expectedRoute.Equal(peer1Routes.Routes[0]), "received route should be equal") peer2Routes, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) @@ -1206,7 +1206,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) require.Len(t, peer2Routes.Routes, 1, "we should receive one route") - require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") + require.True(t, peer1Routes.Routes[0].Equal(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") newGroup := &types.Group{ ID: xid.New().String(), diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index b96ccea42..118a6a3d2 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -38,8 +38,8 @@ type FirewallRule struct { PortRange RulePortRange } -// IsEqual checks if two firewall rules are equal. -func (r *FirewallRule) IsEqual(other *FirewallRule) bool { +// Equal checks if two firewall rules are equal. +func (r *FirewallRule) Equal(other *FirewallRule) bool { return r.PeerIP == other.PeerIP && r.Direction == other.Direction && r.Action == other.Action && diff --git a/management/server/types/network.go b/management/server/types/network.go index 26153a7d5..00082bb41 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -8,11 +8,13 @@ import ( "github.com/c-robinson/iplib" "github.com/rs/xid" + "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -38,12 +40,26 @@ type NetworkMap struct { } func (nm *NetworkMap) Merge(other *NetworkMap) { - nm.Peers = append(nm.Peers, other.Peers...) - nm.Routes = append(nm.Routes, other.Routes...) - nm.OfflinePeers = append(nm.OfflinePeers, other.OfflinePeers...) - nm.FirewallRules = append(nm.FirewallRules, other.FirewallRules...) - nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, other.RoutesFirewallRules...) - nm.ForwardingRules = append(nm.ForwardingRules, other.ForwardingRules...) + nm.Peers = mergeUniquePeersByID(nm.Peers, other.Peers) + nm.Routes = util.MergeUnique(nm.Routes, other.Routes) + nm.OfflinePeers = mergeUniquePeersByID(nm.OfflinePeers, other.OfflinePeers) + nm.FirewallRules = util.MergeUnique(nm.FirewallRules, other.FirewallRules) + nm.RoutesFirewallRules = util.MergeUnique(nm.RoutesFirewallRules, other.RoutesFirewallRules) + nm.ForwardingRules = util.MergeUnique(nm.ForwardingRules, other.ForwardingRules) +} + +func mergeUniquePeersByID(peers1, peers2 []*nbpeer.Peer) []*nbpeer.Peer { + result := make(map[string]*nbpeer.Peer) + for _, peer := range peers1 { + result[peer.ID] = peer + } + for _, peer := range peers2 { + if _, ok := result[peer.ID]; !ok { + result[peer.ID] = peer + } + } + + return maps.Values(result) } type ForwardingRule struct { @@ -75,6 +91,13 @@ func (f *ForwardingRule) ToProto() *proto.ForwardingRule { } } +func (f *ForwardingRule) Equal(other *ForwardingRule) bool { + return f.RuleProtocol == other.RuleProtocol && + f.DestinationPorts.Equal(&other.DestinationPorts) && + f.TranslatedAddress.Equal(other.TranslatedAddress) && + f.TranslatedPorts.Equal(&other.TranslatedPorts) +} + func ipToBytes(ip net.IP) []byte { if ip4 := ip.To4(); ip4 != nil { return ip4 diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index fc34a0b6f..b86732415 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -33,6 +33,10 @@ func (r *RulePortRange) ToProto() *proto.PortInfo { } } +func (r *RulePortRange) Equal(other *RulePortRange) bool { + return r.Start == other.Start && r.End == other.End +} + // PolicyRule is the metadata of the policy type PolicyRule struct { // ID of the policy rule diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go index 64708d68a..18eda7eda 100644 --- a/management/server/types/route_firewall_rule.go +++ b/management/server/types/route_firewall_rule.go @@ -30,3 +30,28 @@ type RouteFirewallRule struct { // isDynamic indicates whether the rule is for DNS routing IsDynamic bool } + +func (r *RouteFirewallRule) Equal(other *RouteFirewallRule) bool { + if r.Action != other.Action { + return false + } + if r.Destination != other.Destination { + return false + } + if r.Protocol != other.Protocol { + return false + } + if r.Port != other.Port { + return false + } + if !r.PortRange.Equal(&other.PortRange) { + return false + } + if !r.Domains.Equal(other.Domains) { + return false + } + if r.IsDynamic != other.IsDynamic { + return false + } + return true +} diff --git a/management/server/util/util.go b/management/server/util/util.go index d85b55f02..617484274 100644 --- a/management/server/util/util.go +++ b/management/server/util/util.go @@ -19,3 +19,34 @@ func Difference(a, b []string) []string { func ToPtr[T any](value T) *T { return &value } + +type comparableObject[T any] interface { + Equal(other T) bool +} + +func MergeUnique[T comparableObject[T]](arr1, arr2 []T) []T { + var result []T + + for _, item := range arr1 { + if !contains(result, item) { + result = append(result, item) + } + } + + for _, item := range arr2 { + if !contains(result, item) { + result = append(result, item) + } + } + + return result +} + +func contains[T comparableObject[T]](slice []T, element T) bool { + for _, item := range slice { + if item.Equal(element) { + return true + } + } + return false +} diff --git a/management/server/util/util_test.go b/management/server/util/util_test.go new file mode 100644 index 000000000..5c928b369 --- /dev/null +++ b/management/server/util/util_test.go @@ -0,0 +1,41 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type testObject struct { + value int +} + +func (t testObject) Equal(other testObject) bool { + return t.value == other.value +} + +func Test_MergeUniqueArraysWithoutDuplicates(t *testing.T) { + arr1 := []testObject{{value: 1}, {value: 2}} + arr2 := []testObject{{value: 2}, {value: 3}} + result := MergeUnique(arr1, arr2) + assert.Len(t, result, 3) + assert.Contains(t, result, testObject{value: 1}) + assert.Contains(t, result, testObject{value: 2}) + assert.Contains(t, result, testObject{value: 3}) +} + +func Test_MergeUniqueHandlesEmptyArrays(t *testing.T) { + arr1 := []testObject{} + arr2 := []testObject{} + result := MergeUnique(arr1, arr2) + assert.Empty(t, result) +} + +func Test_MergeUniqueHandlesOneEmptyArray(t *testing.T) { + arr1 := []testObject{{value: 1}, {value: 2}} + arr2 := []testObject{} + result := MergeUnique(arr1, arr2) + assert.Len(t, result, 2) + assert.Contains(t, result, testObject{value: 1}) + assert.Contains(t, result, testObject{value: 2}) +} diff --git a/route/route.go b/route/route.go index ad2aaba89..f7bf3ea87 100644 --- a/route/route.go +++ b/route/route.go @@ -132,8 +132,8 @@ func (r *Route) Copy() *Route { return route } -// IsEqual compares one route with the other -func (r *Route) IsEqual(other *Route) bool { +// Equal compares one route with the other +func (r *Route) Equal(other *Route) bool { if r == nil && other == nil { return true } else if r == nil || other == nil {