diff --git a/management/server/account.go b/management/server/account.go index 82f5ee4a3..daeaf6e55 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -24,6 +24,7 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -409,14 +410,15 @@ func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.C event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.schedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + am.schedulePeerLoginExpiration(ctx, accountID) } } @@ -454,6 +456,10 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + //nolint + ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID) + //nolint + ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource)) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -478,8 +484,11 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { - am.peerLoginExpiry.Cancel(ctx, []string{accountID}) +func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context, accountID string) { + if am.peerLoginExpiry.IsSchedulerRunning(accountID) { + log.WithContext(ctx).Tracef("peer login expiration job for account %s is already scheduled", accountID) + return + } if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } diff --git a/management/server/account_test.go b/management/server/account_test.go index ba0191c03..c3b1f31a6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1862,11 +1862,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "expecting to update account settings successfully but got error") wg := &sync.WaitGroup{} - wg.Add(2) + wg.Add(1) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(ctx context.Context, IDs []string) { - wg.Done() - }, ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, diff --git a/management/server/peer.go b/management/server/peer.go index 4a468a6cd..f2469e09b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.schedulePeerLoginExpiration(ctx, accountID) } if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { @@ -296,7 +296,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + am.schedulePeerLoginExpiration(ctx, accountID) } } diff --git a/management/server/scheduler.go b/management/server/scheduler.go index 147b50fc6..df73c9a1d 100644 --- a/management/server/scheduler.go +++ b/management/server/scheduler.go @@ -12,6 +12,7 @@ import ( type Scheduler interface { Cancel(ctx context.Context, IDs []string) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + IsSchedulerRunning(ID string) bool } // MockScheduler is a mock implementation of Scheduler @@ -26,7 +27,7 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) { mock.CancelFunc(ctx, IDs) return } - log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ") + log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ") } // Schedule mocks the Schedule function of the Scheduler interface @@ -35,7 +36,13 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st mock.ScheduleFunc(ctx, in, ID, job) return } - log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined") + log.WithContext(ctx).Warnf("MockScheduler doesn't have Schedule function defined") +} + +func (mock *MockScheduler) IsSchedulerRunning(ID string) bool { + // MockScheduler does not implement IsSchedulerRunning, so we return false + log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined") + return false } // DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. @@ -124,3 +131,11 @@ func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID s }() } + +// IsSchedulerRunning checks if a job with the provided ID is scheduled to run +func (wm *DefaultScheduler) IsSchedulerRunning(ID string) bool { + wm.mu.Lock() + defer wm.mu.Unlock() + _, ok := wm.jobs[ID] + return ok +}