mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 10:20:09 +01:00
[client/management] add peer lock to peer meta update and fix isEqual func (#2840)
This commit is contained in:
parent
44e799c687
commit
4aee3c9e33
@ -11,6 +11,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -1484,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
|
|||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
|
for _, check := range checks {
|
||||||
|
sort.Slice(check.Files, func(i, j int) bool {
|
||||||
|
return check.Files[i] < check.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, oCheck := range oChecks {
|
||||||
|
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||||
|
return oCheck.Files[i] < oCheck.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||||
return slices.Equal(checks.Files, oChecks.Files)
|
return slices.Equal(checks.Files, oChecks.Files)
|
||||||
})
|
})
|
||||||
|
@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_CheckFilesEqual(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputChecks1 []*mgmtProto.Checks
|
||||||
|
inputChecks2 []*mgmtProto.Checks
|
||||||
|
expectedBool bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equal Files In Equal Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Equal Files In Reverse Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile2",
|
||||||
|
"testfile1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unequal Files Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Compared With Empty Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||||
|
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2319,7 +2319,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
|||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -2335,6 +2335,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
|||||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
|
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||||
|
defer unlockPeer()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
|||||||
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||||
|
|
||||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to save peer status: %w", err)
|
return false, fmt.Errorf("failed to save peer status: %w", err)
|
||||||
@ -657,6 +659,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
|
|
||||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||||
if updated {
|
if updated {
|
||||||
|
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
||||||
|
Loading…
Reference in New Issue
Block a user