From f984b8a0916acb1a131585e2a9545f3ae4abee30 Mon Sep 17 00:00:00 2001
From: Misha Bragin
Date: Mon, 27 Feb 2023 16:44:26 +0100
Subject: [PATCH] Proactively expire peers' login per account (#698)

Goals:
- Enable peer login expiration when adding a new peer
- Expire a peer's login when the time comes

The account manager triggers the peer expiration routine in the future if the following conditions are true:
- peer expiration is enabled for the account
- there is at least one peer that has expiration enabled and is connected

The time of the next expiration check is based on the nearest peer expiration. The account manager finds the peer with the oldest last login (auth) timestamp and calculates when it has to run the routine as the sum of the configured peer login expiration duration and the peer's last login time.

When triggered, the expiration routine checks whether there are expired peers. The management server closes the update channels of these peers and updates the network map of the other peers to exclude the expired ones, so that expired peers are not able to connect anywhere.

The account manager can reschedule or cancel peer expiration in the following cases:
- when an admin changes the account setting (peer expiration enable/disable)
- when an admin updates the expiration duration of the account
- when an admin updates peer expiration (enable/disable) of a peer
- when a peer connects (Sync)

P.S. The network map calculation was updated to exclude peers whose login has expired.
---
 management/server/account.go         | 110 ++++++-
 management/server/account_test.go    | 435 +++++++++++++++++++++++++++
 management/server/grpcserver.go      |   7 +-
 management/server/peer.go            |  51 +++-
 management/server/peer_test.go       |   2 +-
 management/server/scheduler.go       | 114 +++++++
 management/server/scheduler_test.go  |  94 ++++++
 management/server/turncredentials.go |   1 +
 management/server/updatechannel.go   |  21 +-
 9 files changed, 814 insertions(+), 21 deletions(-)
 create mode 100644 management/server/scheduler.go
 create mode 100644 management/server/scheduler_test.go

diff --git a/management/server/account.go b/management/server/account.go index e2293e508..ac00462fa 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -119,7 +119,8 @@ type DefaultAccountManager struct { // singleAccountModeDomain is a domain to use in singleAccountMode setup singleAccountModeDomain string // dnsDomain is used for peer resolution. This is appended to the peer's name - dnsDomain string + dnsDomain string + peerLoginExpiry Scheduler } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -307,6 +308,58 @@ func (a *Account) GetGroup(groupID string) *Group { return a.Groups[groupID] } +// GetExpiredPeers returns peers that have been expired +func (a *Account) GetExpiredPeers() []*Peer { + var peers []*Peer + for _, peer := range a.GetPeersWithExpiration() { + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers +} + +// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected.
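+// For example (illustrative values): with Settings.PeerLoginExpiration set to 1 hour and a connected,
+// non-expired peer whose LastLogin was 45 minutes ago, the smallest remaining duration across peers is
+// roughly 15 minutes, so this returns (~15*time.Minute, true) and the account's expiration job is
+// scheduled to run in about 15 minutes.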
+func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { + + peersWithExpiry := a.GetPeersWithExpiration() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true +func (a *Account) GetPeersWithExpiration() []*Peer { + peers := make([]*Peer, 0) + for _, peer := range a.Peers { + if peer.LoginExpirationEnabled { + peers = append(peers, peer) + } + } + return peers +} + // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*Peer { var peers []*Peer @@ -550,13 +603,14 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage cacheLoading: map[string]chan struct{}{}, dnsDomain: dnsDomain, eventStore: eventStore, + peerLoginExpiry: NewDefaultScheduler(), } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 if am.singleAccountMode { if !isDomainValid(singleAccountModeDomain) { - return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for single accound mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) + return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. 
Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain log.Infof("single account mode enabled, accounts number %d", len(allAccounts)) @@ -640,12 +694,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled + am.peerLoginExpiry.Cancel([]string{accountID}) + } else { + am.checkAndSchedulePeerLoginExpiration(account) } am.storeEvent(userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) + am.checkAndSchedulePeerLoginExpiration(account) } updatedAccount := account.UpdateSettings(newSettings) @@ -658,6 +716,54 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return updatedAccount, nil } +func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) { + return func() (time.Duration, bool) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + // account is nil when the lookup fails, so log the requested ID and don't reschedule + log.Errorf("failed getting account %s expiring peers", accountID) + return 0, false + } + + var peerIDs []string + for _, peer := range account.GetExpiredPeers() { + if peer.Status.LoginExpired { + continue + } + peerIDs = append(peerIDs, peer.ID) + peer.MarkLoginExpired(true) + account.UpdatePeer(peer) + err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) + if err != nil { + log.Errorf("failed saving peer status while expiring peer %s", peer.ID) + return account.GetNextPeerExpiration() + } + } + + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + + if len(peerIDs) != 0 { + // this will trigger peer disconnect from the management service + am.peersUpdateManager.CloseChannels(peerIDs) + err := am.updateAccountPeers(account) + if err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", accountID) + return account.GetNextPeerExpiration() + } + } + return account.GetNextPeerExpiration() + } +} + +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) { + am.peerLoginExpiry.Cancel([]string{account.Id}) + if nextRun, ok := account.GetNextPeerExpiration(); ok { + go am.peerLoginExpiry.Schedule(nextRun, account.Id, am.peerLoginExpirationJob(account.Id)) + } +} + // newAccount creates a new Account with a generated ID and generated default setup keys. 
// If ID is already in use (due to collision) we try one more time before returning error func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 979c41c86..1d672e1b7 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1294,6 +1294,147 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour) } +func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peer, err := manager.AddPeer("", userID, &Peer{ + Key: key.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer", + LoginExpirationEnabled: true, + }) + require.NoError(t, err, "unable to add peer") + err = manager.MarkPeerConnected(key.PublicKey().String(), true) + require.NoError(t, err, "unable to mark peer connected") + account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: true}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + + wg := &sync.WaitGroup{} + wg.Add(2) + manager.peerLoginExpiry = &MockScheduler{ + CancelFunc: func(IDs []string) { + wg.Done() + }, + ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wg.Done() + }, + } + + // disable expiration first + update := peer.Copy() + update.LoginExpirationEnabled = false + _, err = manager.UpdatePeer(account.Id, userID, update) + require.NoError(t, err, "unable to update peer") + // enabling expiration should trigger the routine + update.LoginExpirationEnabled = true + _, err = manager.UpdatePeer(account.Id, userID, update) + require.NoError(t, err, "unable to update peer") + + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } +} + +func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + _, err = manager.AddPeer("", userID, &Peer{ + Key: key.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer", + LoginExpirationEnabled: true, + }) + require.NoError(t, err, "unable to add peer") + _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: true}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + + wg := &sync.WaitGroup{} + wg.Add(2) + manager.peerLoginExpiry = &MockScheduler{ + CancelFunc: func(IDs []string) { + wg.Done() + }, + ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wg.Done() + }, + } + + // when we mark peer as connected, the peer login 
expiration routine should trigger + err = manager.MarkPeerConnected(key.PublicKey().String(), true) + require.NoError(t, err, "unable to mark peer connected") + + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } + +} + +func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + _, err = manager.AddPeer("", userID, &Peer{ + Key: key.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer", + LoginExpirationEnabled: true, + }) + require.NoError(t, err, "unable to add peer") + err = manager.MarkPeerConnected(key.PublicKey().String(), true) + require.NoError(t, err, "unable to mark peer connected") + + wg := &sync.WaitGroup{} + wg.Add(2) + manager.peerLoginExpiry = &MockScheduler{ + CancelFunc: func(IDs []string) { + wg.Done() + }, + ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wg.Done() + }, + } + // enabling PeerLoginExpirationEnabled should trigger the expiration job + account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: true}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } + wg.Add(1) + + // disabling PeerLoginExpirationEnabled should trigger cancel + _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: false}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + failed = waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } +} func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) @@ -1326,6 +1467,286 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } +func TestAccount_GetExpiredPeers(t *testing.T) { + type test struct { + name string + peers map[string]*Peer + expectedPeers map[string]struct{} + } + testCases := []test{ + { + name: "Peers with login expiration disabled, no expired peers", + peers: map[string]*Peer{ + "peer-1": { + LoginExpirationEnabled: false, + }, + "peer-2": { + LoginExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Two peers expired", + peers: map[string]*Peer{ + "peer-1": { + ID: "peer-1", + LoginExpirationEnabled: true, + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().Add(-30 * time.Minute), + }, + "peer-2": { + ID: "peer-2", + LoginExpirationEnabled: true, + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().Add(-2 * time.Hour), + }, + + "peer-3": { + ID: "peer-3", + LoginExpirationEnabled: true, + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().Add(-1 * time.Hour), + 
}, + }, + expectedPeers: map[string]struct{}{ + "peer-2": {}, + "peer-3": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour, + }, + } + + expiredPeers := account.GetExpiredPeers() + assert.Len(t, expiredPeers, len(testCase.expectedPeers)) + for _, peer := range expiredPeers { + if _, ok := testCase.expectedPeers[peer.ID]; !ok { + t.Fatalf("expected to have peer %s expired", peer.ID) + } + } + }) + } + +} + +func TestAccount_GetPeersWithExpiration(t *testing.T) { + type test struct { + name string + peers map[string]*Peer + expectedPeers map[string]struct{} + } + + testCases := []test{ + { + name: "No account peers, no peers with expiration", + peers: map[string]*Peer{}, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration disabled, no peers with expiration", + peers: map[string]*Peer{ + "peer-1": { + LoginExpirationEnabled: false, + }, + "peer-2": { + LoginExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration enabled, return peers with expiration", + peers: map[string]*Peer{ + "peer-1": { + ID: "peer-1", + LoginExpirationEnabled: true, + }, + "peer-2": { + LoginExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + } + + actual := account.GetPeersWithExpiration() + assert.Len(t, actual, len(testCase.expectedPeers)) + if len(testCase.expectedPeers) > 0 { + for k := range testCase.expectedPeers { + contains := false + for _, peer := range actual { + if k == peer.ID { + contains = true + } + } + assert.True(t, contains) + } + } + }) + } + +} + +func TestAccount_GetNextPeerExpiration(t *testing.T) { + + type test struct { + name string + peers map[string]*Peer + expiration time.Duration + expirationEnabled bool + expectedNextRun bool + expectedNextExpiration time.Duration + } + + expectedNextExpiration := time.Minute + testCases := []test{ + { + name: "No peers, no expiration", + peers: map[string]*Peer{}, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "No connected peers, no expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: false, + }, + LoginExpirationEnabled: true, + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + }, + LoginExpirationEnabled: false, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Connected peers with disabled expiration, no expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: true, + }, + LoginExpirationEnabled: false, + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + }, + LoginExpirationEnabled: false, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Expired peers, no expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: true, + }, + LoginExpirationEnabled: true, + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: true, + 
}, + LoginExpirationEnabled: true, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "To be expired peer, return expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: false, + }, + LoginExpirationEnabled: true, + LastLogin: time.Now(), + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: true, + }, + LoginExpirationEnabled: true, + }, + }, + expiration: time.Minute, + expirationEnabled: false, + expectedNextRun: true, + expectedNextExpiration: expectedNextExpiration, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, + } + + expiration, ok := account.GetNextPeerExpiration() + assert.Equal(t, ok, testCase.expectedNextRun) + if testCase.expectedNextRun { + assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration) + } else { + assert.Equal(t, expiration, testCase.expectedNextExpiration) + } + + }) + } + +} + func createManager(t *testing.T) (*DefaultAccountManager, error) { store, err := createStore(t) if err != nil { @@ -1344,3 +1765,17 @@ func createStore(t *testing.T) (Store, error) { return store, nil } + +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + return false + case <-time.After(timeout): + return true + } +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 7d4d60207..0ee9e0715 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -136,7 +136,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if err != nil { return status.Error(codes.Internal, "internal server error") } - expired, left := peer.LoginExpired(account.Settings) + expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration) + expired = account.Settings.PeerLoginExpirationEnabled && expired if peer.UserID != "" && (expired || peer.Status.LoginExpired) { err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true) if err != nil { @@ -380,7 +381,9 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p if err != nil { return nil, status.Error(codes.Internal, "internal server error") } - expired, left := peer.LoginExpired(account.Settings) + + expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration) + expired = account.Settings.PeerLoginExpirationEnabled && expired if peer.UserID != "" && (expired || peer.Status.LoginExpired) { // it might be that peer expired but user has logged in already, check token then if loginReq.GetJwtToken() == "" { diff --git a/management/server/peer.go b/management/server/peer.go index 5e3f5e69b..49732421f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -93,17 +93,24 @@ func (p *Peer) Copy() *Peer { } } +// MarkLoginExpired marks peer's status expired or not +func (p *Peer) MarkLoginExpired(expired bool) { + newStatus := p.Status.Copy() + newStatus.LastSeen = time.Now() + newStatus.LoginExpired = expired + p.Status = newStatus +} + // LoginExpired indicates whether the peer's login has expired or not. 
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). // Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. -// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled -// and if disabled on the Account level, then Peer.LoginExpirationEnabled is ineffective. -func (p *Peer) LoginExpired(accountSettings *Settings) (bool, time.Duration) { - expiresAt := p.LastLogin.Add(accountSettings.PeerLoginExpiration) +// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) { + expiresAt := p.LastLogin.Add(expiresIn) now := time.Now() timeLeft := expiresAt.Sub(now) - return accountSettings.PeerLoginExpirationEnabled && p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft + return p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft } // FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain @@ -202,13 +209,10 @@ func (am *DefaultAccountManager) MarkPeerLoginExpired(peerPubKey string, loginEx return err } - newStatus := peer.Status.Copy() - newStatus.LastSeen = time.Now() - newStatus.LoginExpired = loginExpired - peer.Status = newStatus + peer.MarkLoginExpired(loginExpired) account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) if err != nil { return err } @@ -237,7 +241,8 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected return err } - newStatus := peer.Status.Copy() + oldStatus := peer.Status.Copy() + newStatus := oldStatus newStatus.LastSeen = time.Now() newStatus.Connected = connected // whenever peer got connected that means that it logged in successfully @@ -251,6 +256,20 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected if err != nil { return err } + + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(account) + } + + if oldStatus.LoginExpired { + // we need to update other peers because when peer login expires all other peers are notified to disconnect from + //the expired one. Here we notify them that connection is now allowed again. 
+ err = am.updateAccountPeers(account) + if err != nil { + return err + } + } + return nil } @@ -307,6 +326,10 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe event = activity.PeerLoginExpirationDisabled } am.storeEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(account) + } } account.UpdatePeer(peer) @@ -529,7 +552,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* SSHEnabled: false, SSHKey: peer.SSHKey, LastLogin: time.Now(), - LoginExpirationEnabled: false, + LoginExpirationEnabled: true, } // add peer to 'All' group @@ -775,6 +798,10 @@ func (a *Account) getPeersByACL(peerID string) []*Peer { ) continue } + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + continue + } // exclude original peer if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok { peersSet[peer.ID] = struct{}{} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index eb503d218..5ebbad4ec 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -57,7 +57,7 @@ func TestPeer_LoginExpired(t *testing.T) { LastLogin: c.lastLogin, } - expired, _ := peer.LoginExpired(c.accountSettings) + expired, _ := peer.LoginExpired(c.accountSettings.PeerLoginExpiration) assert.Equal(t, expired, c.expected) }) } diff --git a/management/server/scheduler.go b/management/server/scheduler.go new file mode 100644 index 000000000..a35bdc30c --- /dev/null +++ b/management/server/scheduler.go @@ -0,0 +1,114 @@ +package server + +import ( + log "github.com/sirupsen/logrus" + "sync" + "time" +) + +// Scheduler is an interface which implementations can schedule and cancel jobs +type Scheduler interface { + Cancel(IDs []string) + Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) +} + +// MockScheduler is a mock implementation of Scheduler +type MockScheduler struct { + CancelFunc func(IDs []string) + ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) +} + +// Cancel mocks the Cancel function of the Scheduler interface +func (mock *MockScheduler) Cancel(IDs []string) { + if mock.CancelFunc != nil { + mock.CancelFunc(IDs) + return + } + log.Errorf("MockScheduler doesn't have Cancel function defined ") +} + +// Schedule mocks the Schedule function of the Scheduler interface +func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + if mock.ScheduleFunc != nil { + mock.ScheduleFunc(in, ID, job) + return + } + log.Errorf("MockScheduler doesn't have Schedule function defined") +} + +// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. 
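+// A minimal usage sketch (the job ID and durations below are illustrative):
+//
+//	scheduler := NewDefaultScheduler()
+//	scheduler.Schedule(time.Hour, "example-job-id", func() (time.Duration, bool) {
+//		// do the work, then ask to be rescheduled in 30 minutes
+//		return 30 * time.Minute, true
+//	})
+//	scheduler.Cancel([]string{"example-job-id"})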
+type DefaultScheduler struct { + // jobs map holds cancellation channels indexed by the job ID + jobs map[string]chan struct{} + mu *sync.Mutex +} + +// NewDefaultScheduler creates an instance of a DefaultScheduler +func NewDefaultScheduler() *DefaultScheduler { + return &DefaultScheduler{ + jobs: make(map[string]chan struct{}), + mu: &sync.Mutex{}, + } +} + +func (wm *DefaultScheduler) cancel(ID string) bool { + cancel, ok := wm.jobs[ID] + if ok { + delete(wm.jobs, ID) + select { + case cancel <- struct{}{}: + log.Debugf("cancelled scheduled job %s", ID) + default: + log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID) + return false + } + + } + return ok +} + +// Cancel cancels the scheduled job by ID if present. +// If job wasn't found the function returns false. +func (wm *DefaultScheduler) Cancel(IDs []string) { + wm.mu.Lock() + defer wm.mu.Unlock() + + for _, id := range IDs { + wm.cancel(id) + } +} + +// Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time. +// If job with the provided ID already exists, a new one won't be scheduled. +func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wm.mu.Lock() + defer wm.mu.Unlock() + cancel := make(chan struct{}) + if _, ok := wm.jobs[ID]; ok { + log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.", + ID, len(wm.jobs)) + return + } + + wm.jobs[ID] = cancel + log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) + go func() { + select { + case <-time.After(in): + log.Debugf("time to do a scheduled job %s", ID) + runIn, reschedule := job() + wm.mu.Lock() + defer wm.mu.Unlock() + delete(wm.jobs, ID) + if reschedule { + go wm.Schedule(runIn, ID, job) + } + case <-cancel: + log.Debugf("stopped scheduled job %s ", ID) + wm.mu.Lock() + defer wm.mu.Unlock() + delete(wm.jobs, ID) + return + } + }() +} diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go new file mode 100644 index 000000000..0c0cef99b --- /dev/null +++ b/management/server/scheduler_test.go @@ -0,0 +1,94 @@ +package server + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "math/rand" + "sync" + "testing" + "time" +) + +func TestScheduler_Performance(t *testing.T) { + scheduler := NewDefaultScheduler() + n := 500 + wg := &sync.WaitGroup{} + wg.Add(n) + maxMs := 500 + minMs := 50 + for i := 0; i < n; i++ { + millis := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond + go scheduler.Schedule(millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) { + time.Sleep(millis) + wg.Done() + return 0, false + }) + } + failed := waitTimeout(wg, 3*time.Second) + if failed { + t.Fatal("timed out while waiting for test to finish") + return + } + assert.Len(t, scheduler.jobs, 0) +} + +func TestScheduler_Cancel(t *testing.T) { + jobID1 := "test-scheduler-job-1" + jobID2 := "test-scheduler-job-2" + scheduler := NewDefaultScheduler() + scheduler.Schedule(2*time.Second, jobID1, func() (nextRunIn time.Duration, reschedule bool) { + return 0, false + }) + scheduler.Schedule(2*time.Second, jobID2, func() (nextRunIn time.Duration, reschedule bool) { + return 0, false + }) + + assert.Len(t, scheduler.jobs, 2) + scheduler.Cancel([]string{jobID1}) + assert.Len(t, scheduler.jobs, 1) + assert.NotNil(t, scheduler.jobs[jobID2]) +} + +func 
TestScheduler_Schedule(t *testing.T) { + jobID := "test-scheduler-job-1" + scheduler := NewDefaultScheduler() + wg := &sync.WaitGroup{} + wg.Add(1) + // job without reschedule should be triggered once + job := func() (nextRunIn time.Duration, reschedule bool) { + wg.Done() + return 0, false + } + scheduler.Schedule(300*time.Millisecond, jobID, job) + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timed out while waiting for test to finish") + return + } + + // job with reschedule should be triggered at least twice + wg = &sync.WaitGroup{} + mx := &sync.Mutex{} + scheduledTimes := 0 + wg.Add(2) + job = func() (nextRunIn time.Duration, reschedule bool) { + mx.Lock() + defer mx.Unlock() + // ensure we repeat only twice + if scheduledTimes < 2 { + wg.Done() + scheduledTimes++ + return 300 * time.Millisecond, true + } + return 0, false + } + + scheduler.Schedule(300*time.Millisecond, jobID, job) + failed = waitTimeout(wg, time.Second) + if failed { + t.Fatal("timed out while waiting for test to finish") + return + } + scheduler.cancel(jobID) + +} diff --git a/management/server/turncredentials.go b/management/server/turncredentials.go index dcfab57dd..752376767 100644 --- a/management/server/turncredentials.go +++ b/management/server/turncredentials.go @@ -115,6 +115,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { Turns: turns, }, } + log.Debugf("sending new TURN credentials to peer %s", peerID) err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) if err != nil { log.Errorf("error while sending TURN update to peer %s %v", peerID, err) diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 4b4d6e3d1..6cc10ad24 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -60,10 +60,7 @@ func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { return channel } -// CloseChannel closes updates channel of a given peer -func (p *PeersUpdateManager) CloseChannel(peerID string) { - p.channelsMux.Lock() - defer p.channelsMux.Unlock() +func (p *PeersUpdateManager) closeChannel(peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) @@ -72,6 +69,22 @@ func (p *PeersUpdateManager) CloseChannel(peerID string) { log.Debugf("closed updates channel of a peer %s", peerID) } +// CloseChannels closes updates channel for each given peer +func (p *PeersUpdateManager) CloseChannels(peerIDs []string) { + p.channelsMux.Lock() + defer p.channelsMux.Unlock() + for _, id := range peerIDs { + p.closeChannel(id) + } +} + +// CloseChannel closes updates channel of a given peer +func (p *PeersUpdateManager) CloseChannel(peerID string) { + p.channelsMux.Lock() + defer p.channelsMux.Unlock() + p.closeChannel(peerID) +} + // GetAllConnectedPeers returns a copy of the connected peers map func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} { p.channelsMux.Lock()