diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4004f1b57..b1ec66286 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -120,14 +120,20 @@ type MockAccountManager struct { GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - // do nothing + if am.UpdateAccountPeersFunc != nil { + am.UpdateAccountPeersFunc(ctx, accountID) + } } func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - // do nothing + if am.BufferUpdateAccountPeersFunc != nil { + am.BufferUpdateAccountPeersFunc(ctx, accountID) + } } func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { diff --git a/management/server/peer.go b/management/server/peer.go index 21a9579fc..c6ade83c0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -9,6 +9,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/rs/xid" @@ -1280,18 +1281,39 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } } -func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{}) - lock := mu.(*sync.Mutex) +type bufferUpdate struct { + mu sync.Mutex + next *time.Timer + update atomic.Bool +} - if !lock.TryLock() { +func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) return } + if b.next != nil { + b.next.Stop() + } + go func() { - time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load())) - lock.Unlock() + defer b.mu.Unlock() am.UpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() { + am.UpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load())) }() } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 07ec5037b..d41020514 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -25,6 +26,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" @@ -2251,3 +2253,131 @@ func Test_AddPeer(t *testing.T) { assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes) assert.Equal(t, uint64(totalPeers), account.Network.Serial) } + +func TestBufferUpdateAccountPeers(t *testing.T) { + const ( + peersCount = 1000 + updateAccountInterval = 50 * time.Millisecond + ) + + var ( + deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 + uapLastRun, dpLastRun atomic.Int64 + + totalNewRuns, totalOldRuns int + ) + + uap := func(ctx context.Context, accountID string) { + updatePeersDeleted.Store(deletedPeers.Load()) + updatePeersRuns.Add(1) + uapLastRun.Store(time.Now().UnixMilli()) + time.Sleep(100 * time.Millisecond) + } + + t.Run("new approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) + b := mu.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + uap(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + b.next = time.AfterFunc(updateAccountInterval, func() { + uap(ctx, accountID) + }) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalNewRuns = int(updatePeersRuns.Load()) + }) + + t.Run("old approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) + b := mu.(*sync.Mutex) + + if !b.TryLock() { + return + } + + go func() { + time.Sleep(updateAccountInterval) + b.Unlock() + uap(ctx, accountID) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalOldRuns = int(updatePeersRuns.Load()) + }) + assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) + t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) +}