mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-29 03:23:56 +01:00
Proactively expire peers' login per account (#698)
Goals: Enable peer login expiration when adding new peer Expire peer's login when the time comes The account manager triggers peer expiration routine in future if the following conditions are true: peer expiration is enabled for the account there is at least one peer that has expiration enabled and is connected The time of the next expiration check is based on the nearest peer expiration. Account manager finds a peer with the oldest last login (auth) timestamp and calculates the time when it has to run the routine as a sum of the configured peer login expiration duration and the peer's last login time. When triggered, the expiration routine checks whether there are expired peers. The management server closes the update channel of these peers and updates network map of other peers to exclude expired peers so that the expired peers are not able to connect anywhere. The account manager can reschedule or cancel peer expiration in the following cases: when admin changes account setting (peer expiration enable/disable) when admin updates the expiration duration of the account when admin updates peer expiration (enable/disable) when peer connects (Sync) P.S. The network map calculation was updated to exclude peers that have login expired.
This commit is contained in:
parent
9f951c8fb5
commit
f984b8a091
@ -120,6 +120,7 @@ type DefaultAccountManager struct {
|
|||||||
singleAccountModeDomain string
|
singleAccountModeDomain string
|
||||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||||
dnsDomain string
|
dnsDomain string
|
||||||
|
peerLoginExpiry Scheduler
|
||||||
}
|
}
|
||||||
|
|
||||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||||
@ -307,6 +308,58 @@ func (a *Account) GetGroup(groupID string) *Group {
|
|||||||
return a.Groups[groupID]
|
return a.Groups[groupID]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetExpiredPeers returns peers that have been expired
|
||||||
|
func (a *Account) GetExpiredPeers() []*Peer {
|
||||||
|
var peers []*Peer
|
||||||
|
for _, peer := range a.GetPeersWithExpiration() {
|
||||||
|
expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||||
|
if expired {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||||
|
// If there is no peer that expires this function returns false and a duration of 0.
|
||||||
|
// This function only considers peers that haven't been expired yet and that are connected.
|
||||||
|
func (a *Account) GetNextPeerExpiration() (time.Duration, bool) {
|
||||||
|
|
||||||
|
peersWithExpiry := a.GetPeersWithExpiration()
|
||||||
|
if len(peersWithExpiry) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
var nextExpiry *time.Duration
|
||||||
|
for _, peer := range peersWithExpiry {
|
||||||
|
// consider only connected peers because others will require login on connecting to the management server
|
||||||
|
if peer.Status.LoginExpired || !peer.Status.Connected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||||
|
if nextExpiry == nil || duration < *nextExpiry {
|
||||||
|
nextExpiry = &duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextExpiry == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *nextExpiry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true
|
||||||
|
func (a *Account) GetPeersWithExpiration() []*Peer {
|
||||||
|
peers := make([]*Peer, 0)
|
||||||
|
for _, peer := range a.Peers {
|
||||||
|
if peer.LoginExpirationEnabled {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeers returns a list of all Account peers
|
// GetPeers returns a list of all Account peers
|
||||||
func (a *Account) GetPeers() []*Peer {
|
func (a *Account) GetPeers() []*Peer {
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
@ -550,13 +603,14 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
|
|||||||
cacheLoading: map[string]chan struct{}{},
|
cacheLoading: map[string]chan struct{}{},
|
||||||
dnsDomain: dnsDomain,
|
dnsDomain: dnsDomain,
|
||||||
eventStore: eventStore,
|
eventStore: eventStore,
|
||||||
|
peerLoginExpiry: NewDefaultScheduler(),
|
||||||
}
|
}
|
||||||
allAccounts := store.GetAllAccounts()
|
allAccounts := store.GetAllAccounts()
|
||||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||||
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
|
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
|
||||||
if am.singleAccountMode {
|
if am.singleAccountMode {
|
||||||
if !isDomainValid(singleAccountModeDomain) {
|
if !isDomainValid(singleAccountModeDomain) {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for single accound mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
|
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
|
||||||
}
|
}
|
||||||
am.singleAccountModeDomain = singleAccountModeDomain
|
am.singleAccountModeDomain = singleAccountModeDomain
|
||||||
log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
|
log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
|
||||||
@ -640,12 +694,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
|||||||
event := activity.AccountPeerLoginExpirationEnabled
|
event := activity.AccountPeerLoginExpirationEnabled
|
||||||
if !newSettings.PeerLoginExpirationEnabled {
|
if !newSettings.PeerLoginExpirationEnabled {
|
||||||
event = activity.AccountPeerLoginExpirationDisabled
|
event = activity.AccountPeerLoginExpirationDisabled
|
||||||
|
am.peerLoginExpiry.Cancel([]string{accountID})
|
||||||
|
} else {
|
||||||
|
am.checkAndSchedulePeerLoginExpiration(account)
|
||||||
}
|
}
|
||||||
am.storeEvent(userID, accountID, accountID, event, nil)
|
am.storeEvent(userID, accountID, accountID, event, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||||
am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
||||||
|
am.checkAndSchedulePeerLoginExpiration(account)
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount := account.UpdateSettings(newSettings)
|
updatedAccount := account.UpdateSettings(newSettings)
|
||||||
@ -658,6 +716,54 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
|||||||
return updatedAccount, nil
|
return updatedAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
|
||||||
|
return func() (time.Duration, bool) {
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed getting account %s expiring peers", account.Id)
|
||||||
|
return account.GetNextPeerExpiration()
|
||||||
|
}
|
||||||
|
|
||||||
|
var peerIDs []string
|
||||||
|
for _, peer := range account.GetExpiredPeers() {
|
||||||
|
if peer.Status.LoginExpired {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
|
peer.MarkLoginExpired(true)
|
||||||
|
account.UpdatePeer(peer)
|
||||||
|
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
||||||
|
return account.GetNextPeerExpiration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||||
|
|
||||||
|
if len(peerIDs) != 0 {
|
||||||
|
// this will trigger peer disconnect from the management service
|
||||||
|
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||||
|
err := am.updateAccountPeers(account)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
|
||||||
|
return account.GetNextPeerExpiration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return account.GetNextPeerExpiration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) {
|
||||||
|
am.peerLoginExpiry.Cancel([]string{account.Id})
|
||||||
|
if nextRun, ok := account.GetNextPeerExpiration(); ok {
|
||||||
|
go am.peerLoginExpiry.Schedule(nextRun, account.Id, am.peerLoginExpirationJob(account.Id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||||
// If ID is already in use (due to collision) we try one more time before returning error
|
// If ID is already in use (due to collision) we try one more time before returning error
|
||||||
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {
|
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {
|
||||||
|
@ -1294,6 +1294,147 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
|||||||
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
|
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
|
||||||
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
|
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
|
||||||
}
|
}
|
||||||
|
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
|
key, err := wgtypes.GenerateKey()
|
||||||
|
require.NoError(t, err, "unable to generate WireGuard key")
|
||||||
|
peer, err := manager.AddPeer("", userID, &Peer{
|
||||||
|
Key: key.PublicKey().String(),
|
||||||
|
Meta: PeerSystemMeta{},
|
||||||
|
Name: "test-peer",
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
err = manager.MarkPeerConnected(key.PublicKey().String(), true)
|
||||||
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||||
|
PeerLoginExpiration: time.Hour,
|
||||||
|
PeerLoginExpirationEnabled: true})
|
||||||
|
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
manager.peerLoginExpiry = &MockScheduler{
|
||||||
|
CancelFunc: func(IDs []string) {
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// disable expiration first
|
||||||
|
update := peer.Copy()
|
||||||
|
update.LoginExpirationEnabled = false
|
||||||
|
_, err = manager.UpdatePeer(account.Id, userID, update)
|
||||||
|
require.NoError(t, err, "unable to update peer")
|
||||||
|
// enabling expiration should trigger the routine
|
||||||
|
update.LoginExpirationEnabled = true
|
||||||
|
_, err = manager.UpdatePeer(account.Id, userID, update)
|
||||||
|
require.NoError(t, err, "unable to update peer")
|
||||||
|
|
||||||
|
failed := waitTimeout(wg, time.Second)
|
||||||
|
if failed {
|
||||||
|
t.Fatal("timeout while waiting for test to finish")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
|
key, err := wgtypes.GenerateKey()
|
||||||
|
require.NoError(t, err, "unable to generate WireGuard key")
|
||||||
|
_, err = manager.AddPeer("", userID, &Peer{
|
||||||
|
Key: key.PublicKey().String(),
|
||||||
|
Meta: PeerSystemMeta{},
|
||||||
|
Name: "test-peer",
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||||
|
PeerLoginExpiration: time.Hour,
|
||||||
|
PeerLoginExpirationEnabled: true})
|
||||||
|
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
manager.peerLoginExpiry = &MockScheduler{
|
||||||
|
CancelFunc: func(IDs []string) {
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||||
|
err = manager.MarkPeerConnected(key.PublicKey().String(), true)
|
||||||
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
|
failed := waitTimeout(wg, time.Second)
|
||||||
|
if failed {
|
||||||
|
t.Fatal("timeout while waiting for test to finish")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||||
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
|
key, err := wgtypes.GenerateKey()
|
||||||
|
require.NoError(t, err, "unable to generate WireGuard key")
|
||||||
|
_, err = manager.AddPeer("", userID, &Peer{
|
||||||
|
Key: key.PublicKey().String(),
|
||||||
|
Meta: PeerSystemMeta{},
|
||||||
|
Name: "test-peer",
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
err = manager.MarkPeerConnected(key.PublicKey().String(), true)
|
||||||
|
require.NoError(t, err, "unable to mark peer connected")
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
manager.peerLoginExpiry = &MockScheduler{
|
||||||
|
CancelFunc: func(IDs []string) {
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
|
wg.Done()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||||
|
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||||
|
PeerLoginExpiration: time.Hour,
|
||||||
|
PeerLoginExpirationEnabled: true})
|
||||||
|
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||||
|
|
||||||
|
failed := waitTimeout(wg, time.Second)
|
||||||
|
if failed {
|
||||||
|
t.Fatal("timeout while waiting for test to finish")
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
// disabling PeerLoginExpirationEnabled should trigger cancel
|
||||||
|
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||||
|
PeerLoginExpiration: time.Hour,
|
||||||
|
PeerLoginExpirationEnabled: false})
|
||||||
|
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||||
|
failed = waitTimeout(wg, time.Second)
|
||||||
|
if failed {
|
||||||
|
t.Fatal("timeout while waiting for test to finish")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
@ -1326,6 +1467,286 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*Peer
|
||||||
|
expectedPeers map[string]struct{}
|
||||||
|
}
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "Peers with login expiration disabled, no expired peers",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Two peers expired",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
ID: "peer-1",
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
Status: &PeerStatus{
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().Add(-30 * time.Minute),
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
ID: "peer-2",
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
Status: &PeerStatus{
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().Add(-2 * time.Hour),
|
||||||
|
},
|
||||||
|
|
||||||
|
"peer-3": {
|
||||||
|
ID: "peer-3",
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
Status: &PeerStatus{
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LastLogin: time.Now().Add(-1 * time.Hour),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{
|
||||||
|
"peer-2": {},
|
||||||
|
"peer-3": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
Settings: &Settings{
|
||||||
|
PeerLoginExpirationEnabled: true,
|
||||||
|
PeerLoginExpiration: time.Hour,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expiredPeers := account.GetExpiredPeers()
|
||||||
|
assert.Len(t, expiredPeers, len(testCase.expectedPeers))
|
||||||
|
for _, peer := range expiredPeers {
|
||||||
|
if _, ok := testCase.expectedPeers[peer.ID]; !ok {
|
||||||
|
t.Fatalf("expected to have peer %s expired", peer.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*Peer
|
||||||
|
expectedPeers map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "No account peers, no peers with expiration",
|
||||||
|
peers: map[string]*Peer{},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers with login expiration disabled, no peers with expiration",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Peers with login expiration enabled, return peers with expiration",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
ID: "peer-1",
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedPeers: map[string]struct{}{
|
||||||
|
"peer-1": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
}
|
||||||
|
|
||||||
|
actual := account.GetPeersWithExpiration()
|
||||||
|
assert.Len(t, actual, len(testCase.expectedPeers))
|
||||||
|
if len(testCase.expectedPeers) > 0 {
|
||||||
|
for k := range testCase.expectedPeers {
|
||||||
|
contains := false
|
||||||
|
for _, peer := range actual {
|
||||||
|
if k == peer.ID {
|
||||||
|
contains = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, contains)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
peers map[string]*Peer
|
||||||
|
expiration time.Duration
|
||||||
|
expirationEnabled bool
|
||||||
|
expectedNextRun bool
|
||||||
|
expectedNextExpiration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedNextExpiration := time.Minute
|
||||||
|
testCases := []test{
|
||||||
|
{
|
||||||
|
name: "No peers, no expiration",
|
||||||
|
peers: map[string]*Peer{},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No connected peers, no expiration",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: false,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Connected peers with disabled expiration, no expiration",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Expired peers, no expiration",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Second,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: false,
|
||||||
|
expectedNextExpiration: time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To be expired peer, return expiration",
|
||||||
|
peers: map[string]*Peer{
|
||||||
|
"peer-1": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: false,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
LastLogin: time.Now(),
|
||||||
|
},
|
||||||
|
"peer-2": {
|
||||||
|
Status: &PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LoginExpired: true,
|
||||||
|
},
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expiration: time.Minute,
|
||||||
|
expirationEnabled: false,
|
||||||
|
expectedNextRun: true,
|
||||||
|
expectedNextExpiration: expectedNextExpiration,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Peers: testCase.peers,
|
||||||
|
Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled},
|
||||||
|
}
|
||||||
|
|
||||||
|
expiration, ok := account.GetNextPeerExpiration()
|
||||||
|
assert.Equal(t, ok, testCase.expectedNextRun)
|
||||||
|
if testCase.expectedNextRun {
|
||||||
|
assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, expiration, testCase.expectedNextExpiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||||
store, err := createStore(t)
|
store, err := createStore(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1344,3 +1765,17 @@ func createStore(t *testing.T) (Store, error) {
|
|||||||
|
|
||||||
return store, nil
|
return store, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||||
|
c := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(c)
|
||||||
|
wg.Wait()
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-c:
|
||||||
|
return false
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -136,7 +136,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Error(codes.Internal, "internal server error")
|
return status.Error(codes.Internal, "internal server error")
|
||||||
}
|
}
|
||||||
expired, left := peer.LoginExpired(account.Settings)
|
expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration)
|
||||||
|
expired = account.Settings.PeerLoginExpirationEnabled && expired
|
||||||
if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
|
if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
|
||||||
err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true)
|
err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -380,7 +381,9 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "internal server error")
|
return nil, status.Error(codes.Internal, "internal server error")
|
||||||
}
|
}
|
||||||
expired, left := peer.LoginExpired(account.Settings)
|
|
||||||
|
expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration)
|
||||||
|
expired = account.Settings.PeerLoginExpirationEnabled && expired
|
||||||
if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
|
if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
|
||||||
// it might be that peer expired but user has logged in already, check token then
|
// it might be that peer expired but user has logged in already, check token then
|
||||||
if loginReq.GetJwtToken() == "" {
|
if loginReq.GetJwtToken() == "" {
|
||||||
|
@ -93,17 +93,24 @@ func (p *Peer) Copy() *Peer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarkLoginExpired marks peer's status expired or not
|
||||||
|
func (p *Peer) MarkLoginExpired(expired bool) {
|
||||||
|
newStatus := p.Status.Copy()
|
||||||
|
newStatus.LastSeen = time.Now()
|
||||||
|
newStatus.LoginExpired = expired
|
||||||
|
p.Status = newStatus
|
||||||
|
}
|
||||||
|
|
||||||
// LoginExpired indicates whether the peer's login has expired or not.
|
// LoginExpired indicates whether the peer's login has expired or not.
|
||||||
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
||||||
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
||||||
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
|
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
|
||||||
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled
|
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
|
||||||
// and if disabled on the Account level, then Peer.LoginExpirationEnabled is ineffective.
|
func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||||
func (p *Peer) LoginExpired(accountSettings *Settings) (bool, time.Duration) {
|
expiresAt := p.LastLogin.Add(expiresIn)
|
||||||
expiresAt := p.LastLogin.Add(accountSettings.PeerLoginExpiration)
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
timeLeft := expiresAt.Sub(now)
|
timeLeft := expiresAt.Sub(now)
|
||||||
return accountSettings.PeerLoginExpirationEnabled && p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft
|
return p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft
|
||||||
}
|
}
|
||||||
|
|
||||||
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
|
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
|
||||||
@ -202,13 +209,10 @@ func (am *DefaultAccountManager) MarkPeerLoginExpired(peerPubKey string, loginEx
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
newStatus := peer.Status.Copy()
|
peer.MarkLoginExpired(loginExpired)
|
||||||
newStatus.LastSeen = time.Now()
|
|
||||||
newStatus.LoginExpired = loginExpired
|
|
||||||
peer.Status = newStatus
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -237,7 +241,8 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
newStatus := peer.Status.Copy()
|
oldStatus := peer.Status.Copy()
|
||||||
|
newStatus := oldStatus
|
||||||
newStatus.LastSeen = time.Now()
|
newStatus.LastSeen = time.Now()
|
||||||
newStatus.Connected = connected
|
newStatus.Connected = connected
|
||||||
// whenever peer got connected that means that it logged in successfully
|
// whenever peer got connected that means that it logged in successfully
|
||||||
@ -251,6 +256,20 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerLoginExpiration(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldStatus.LoginExpired {
|
||||||
|
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||||
|
//the expired one. Here we notify them that connection is now allowed again.
|
||||||
|
err = am.updateAccountPeers(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -307,6 +326,10 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe
|
|||||||
event = activity.PeerLoginExpirationDisabled
|
event = activity.PeerLoginExpirationDisabled
|
||||||
}
|
}
|
||||||
am.storeEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
am.storeEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||||
|
am.checkAndSchedulePeerLoginExpiration(account)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
@ -529,7 +552,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
|||||||
SSHEnabled: false,
|
SSHEnabled: false,
|
||||||
SSHKey: peer.SSHKey,
|
SSHKey: peer.SSHKey,
|
||||||
LastLogin: time.Now(),
|
LastLogin: time.Now(),
|
||||||
LoginExpirationEnabled: false,
|
LoginExpirationEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// add peer to 'All' group
|
// add peer to 'All' group
|
||||||
@ -775,6 +798,10 @@ func (a *Account) getPeersByACL(peerID string) []*Peer {
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||||
|
if expired {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// exclude original peer
|
// exclude original peer
|
||||||
if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok {
|
if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok {
|
||||||
peersSet[peer.ID] = struct{}{}
|
peersSet[peer.ID] = struct{}{}
|
||||||
|
@ -57,7 +57,7 @@ func TestPeer_LoginExpired(t *testing.T) {
|
|||||||
LastLogin: c.lastLogin,
|
LastLogin: c.lastLogin,
|
||||||
}
|
}
|
||||||
|
|
||||||
expired, _ := peer.LoginExpired(c.accountSettings)
|
expired, _ := peer.LoginExpired(c.accountSettings.PeerLoginExpiration)
|
||||||
assert.Equal(t, expired, c.expected)
|
assert.Equal(t, expired, c.expected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
114
management/server/scheduler.go
Normal file
114
management/server/scheduler.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scheduler is an interface which implementations can schedule and cancel jobs
|
||||||
|
type Scheduler interface {
|
||||||
|
Cancel(IDs []string)
|
||||||
|
Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockScheduler is a mock implementation of Scheduler
|
||||||
|
type MockScheduler struct {
|
||||||
|
CancelFunc func(IDs []string)
|
||||||
|
ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel mocks the Cancel function of the Scheduler interface
|
||||||
|
func (mock *MockScheduler) Cancel(IDs []string) {
|
||||||
|
if mock.CancelFunc != nil {
|
||||||
|
mock.CancelFunc(IDs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("MockScheduler doesn't have Cancel function defined ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schedule mocks the Schedule function of the Scheduler interface
|
||||||
|
func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
|
if mock.ScheduleFunc != nil {
|
||||||
|
mock.ScheduleFunc(in, ID, job)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("MockScheduler doesn't have Schedule function defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
|
||||||
|
type DefaultScheduler struct {
|
||||||
|
// jobs map holds cancellation channels indexed by the job ID
|
||||||
|
jobs map[string]chan struct{}
|
||||||
|
mu *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultScheduler creates an instance of a DefaultScheduler
|
||||||
|
func NewDefaultScheduler() *DefaultScheduler {
|
||||||
|
return &DefaultScheduler{
|
||||||
|
jobs: make(map[string]chan struct{}),
|
||||||
|
mu: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wm *DefaultScheduler) cancel(ID string) bool {
|
||||||
|
cancel, ok := wm.jobs[ID]
|
||||||
|
if ok {
|
||||||
|
delete(wm.jobs, ID)
|
||||||
|
select {
|
||||||
|
case cancel <- struct{}{}:
|
||||||
|
log.Debugf("cancelled scheduled job %s", ID)
|
||||||
|
default:
|
||||||
|
log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel cancels the scheduled job by ID if present.
|
||||||
|
// If job wasn't found the function returns false.
|
||||||
|
func (wm *DefaultScheduler) Cancel(IDs []string) {
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
|
||||||
|
for _, id := range IDs {
|
||||||
|
wm.cancel(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time.
|
||||||
|
// If job with the provided ID already exists, a new one won't be scheduled.
|
||||||
|
func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
cancel := make(chan struct{})
|
||||||
|
if _, ok := wm.jobs[ID]; ok {
|
||||||
|
log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.",
|
||||||
|
ID, len(wm.jobs))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wm.jobs[ID] = cancel
|
||||||
|
log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs))
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-time.After(in):
|
||||||
|
log.Debugf("time to do a scheduled job %s", ID)
|
||||||
|
runIn, reschedule := job()
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
delete(wm.jobs, ID)
|
||||||
|
if reschedule {
|
||||||
|
go wm.Schedule(runIn, ID, job)
|
||||||
|
}
|
||||||
|
case <-cancel:
|
||||||
|
log.Debugf("stopped scheduled job %s ", ID)
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
delete(wm.jobs, ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
94
management/server/scheduler_test.go
Normal file
94
management/server/scheduler_test.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestScheduler_Performance(t *testing.T) {
|
||||||
|
scheduler := NewDefaultScheduler()
|
||||||
|
n := 500
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(n)
|
||||||
|
maxMs := 500
|
||||||
|
minMs := 50
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
millis := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond
|
||||||
|
go scheduler.Schedule(millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) {
|
||||||
|
time.Sleep(millis)
|
||||||
|
wg.Done()
|
||||||
|
return 0, false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
failed := waitTimeout(wg, 3*time.Second)
|
||||||
|
if failed {
|
||||||
|
t.Fatal("timed out while waiting for test to finish")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Len(t, scheduler.jobs, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScheduler_Cancel(t *testing.T) {
|
||||||
|
jobID1 := "test-scheduler-job-1"
|
||||||
|
jobID2 := "test-scheduler-job-2"
|
||||||
|
scheduler := NewDefaultScheduler()
|
||||||
|
scheduler.Schedule(2*time.Second, jobID1, func() (nextRunIn time.Duration, reschedule bool) {
|
||||||
|
return 0, false
|
||||||
|
})
|
||||||
|
scheduler.Schedule(2*time.Second, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
|
||||||
|
return 0, false
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Len(t, scheduler.jobs, 2)
|
||||||
|
scheduler.Cancel([]string{jobID1})
|
||||||
|
assert.Len(t, scheduler.jobs, 1)
|
||||||
|
assert.NotNil(t, scheduler.jobs[jobID2])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScheduler_Schedule(t *testing.T) {
|
||||||
|
jobID := "test-scheduler-job-1"
|
||||||
|
scheduler := NewDefaultScheduler()
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
// job without reschedule should be triggered once
|
||||||
|
job := func() (nextRunIn time.Duration, reschedule bool) {
|
||||||
|
wg.Done()
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
scheduler.Schedule(300*time.Millisecond, jobID, job)
|
||||||
|
failed := waitTimeout(wg, time.Second)
|
||||||
|
if failed {
|
||||||
|
t.Fatal("timed out while waiting for test to finish")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// job with reschedule should be triggered at least twice
|
||||||
|
wg = &sync.WaitGroup{}
|
||||||
|
mx := &sync.Mutex{}
|
||||||
|
scheduledTimes := 0
|
||||||
|
wg.Add(2)
|
||||||
|
job = func() (nextRunIn time.Duration, reschedule bool) {
|
||||||
|
mx.Lock()
|
||||||
|
defer mx.Unlock()
|
||||||
|
// ensure we repeat only twice
|
||||||
|
if scheduledTimes < 2 {
|
||||||
|
wg.Done()
|
||||||
|
scheduledTimes++
|
||||||
|
return 300 * time.Millisecond, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
scheduler.Schedule(300*time.Millisecond, jobID, job)
|
||||||
|
failed = waitTimeout(wg, time.Second)
|
||||||
|
if failed {
|
||||||
|
t.Fatal("timed out while waiting for test to finish")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scheduler.cancel(jobID)
|
||||||
|
|
||||||
|
}
|
@ -115,6 +115,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
|
|||||||
Turns: turns,
|
Turns: turns,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
log.Debugf("sending new TURN credentials to peer %s", peerID)
|
||||||
err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
|
err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while sending TURN update to peer %s %v", peerID, err)
|
log.Errorf("error while sending TURN update to peer %s %v", peerID, err)
|
||||||
|
@ -60,10 +60,7 @@ func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage {
|
|||||||
return channel
|
return channel
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseChannel closes updates channel of a given peer
|
func (p *PeersUpdateManager) closeChannel(peerID string) {
|
||||||
func (p *PeersUpdateManager) CloseChannel(peerID string) {
|
|
||||||
p.channelsMux.Lock()
|
|
||||||
defer p.channelsMux.Unlock()
|
|
||||||
if channel, ok := p.peerChannels[peerID]; ok {
|
if channel, ok := p.peerChannels[peerID]; ok {
|
||||||
delete(p.peerChannels, peerID)
|
delete(p.peerChannels, peerID)
|
||||||
close(channel)
|
close(channel)
|
||||||
@ -72,6 +69,22 @@ func (p *PeersUpdateManager) CloseChannel(peerID string) {
|
|||||||
log.Debugf("closed updates channel of a peer %s", peerID)
|
log.Debugf("closed updates channel of a peer %s", peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseChannels closes updates channel for each given peer
|
||||||
|
func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer p.channelsMux.Unlock()
|
||||||
|
for _, id := range peerIDs {
|
||||||
|
p.closeChannel(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseChannel closes updates channel of a given peer
|
||||||
|
func (p *PeersUpdateManager) CloseChannel(peerID string) {
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer p.channelsMux.Unlock()
|
||||||
|
p.closeChannel(peerID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetAllConnectedPeers returns a copy of the connected peers map
|
// GetAllConnectedPeers returns a copy of the connected peers map
|
||||||
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
|
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
|
||||||
p.channelsMux.Lock()
|
p.channelsMux.Lock()
|
||||||
|
Loading…
Reference in New Issue
Block a user