mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
add peer lock to peer meta update and fix isEqual func
This commit is contained in:
parent
a9d06b883f
commit
b0518933cb
@ -11,6 +11,7 @@ import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -38,7 +39,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@ -171,7 +171,7 @@ type Engine struct {
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
srWatcher *guard.SRWatcher
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@ -1481,6 +1481,17 @@ func (e *Engine) stopDNSServer() {
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
sort.Slice(check.Files, func(i, j int) bool {
|
||||
return check.Files[i] < check.Files[j]
|
||||
})
|
||||
}
|
||||
for _, oCheck := range oChecks {
|
||||
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||
return oCheck.Files[i] < oCheck.Files[j]
|
||||
})
|
||||
}
|
||||
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
|
@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckFilesEqual(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputChecks1 []*mgmtProto.Checks
|
||||
inputChecks2 []*mgmtProto.Checks
|
||||
expectedBool bool
|
||||
}{
|
||||
{
|
||||
name: "Equal Files In Equal Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Equal Files In Reverse Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile2",
|
||||
"testfile1",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Unequal Files Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile3",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
{
|
||||
name: "Compared With Empty Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
@ -2319,7 +2319,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -2335,6 +2335,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer unlockPeer()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -166,6 +166,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
|
||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -654,6 +656,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
log.WithContext(ctx).Debugf("peer %s metadata updated", peer.ID)
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
|
Loading…
Reference in New Issue
Block a user