Proactively expire peers' login per account (#698)

Goals:

Enable peer login expiration when adding new peer
Expire peer's login when the time comes
The account manager triggers peer expiration routine in future if the
following conditions are true:

peer expiration is enabled for the account
there is at least one peer that has expiration enabled and is connected
The time of the next expiration check is based on the nearest peer expiration.
Account manager finds a peer with the oldest last login (auth) timestamp and
calculates the time when it has to run the routine as a sum of the configured
peer login expiration duration and the peer's last login time.

When triggered, the expiration routine checks whether there are expired peers.
The management server closes the update channel of these peers and updates
network map of other peers to exclude expired peers so that the expired peers
are not able to connect anywhere.

The account manager can reschedule or cancel peer expiration in the following cases:

when admin changes account setting (peer expiration enable/disable)
when admin updates the expiration duration of the account
when admin updates peer expiration (enable/disable)
when peer connects (Sync)
P.S. The network map calculation was updated to exclude peers that have login expired.
This commit is contained in:
Misha Bragin 2023-02-27 16:44:26 +01:00 committed by GitHub
parent 9f951c8fb5
commit f984b8a091
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 814 additions and 21 deletions

View File

@ -119,7 +119,8 @@ type DefaultAccountManager struct {
// singleAccountModeDomain is a domain to use in singleAccountMode setup // singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name // dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string dnsDomain string
peerLoginExpiry Scheduler
} }
// Settings represents Account settings structure that can be modified via API and Dashboard // Settings represents Account settings structure that can be modified via API and Dashboard
@ -307,6 +308,58 @@ func (a *Account) GetGroup(groupID string) *Group {
return a.Groups[groupID] return a.Groups[groupID]
} }
// GetExpiredPeers returns peers that have been expired
func (a *Account) GetExpiredPeers() []*Peer {
var peers []*Peer
for _, peer := range a.GetPeersWithExpiration() {
expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)
}
}
return peers
}
// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (a *Account) GetNextPeerExpiration() (time.Duration, bool) {
peersWithExpiry := a.GetPeersWithExpiration()
if len(peersWithExpiry) == 0 {
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
// consider only connected peers because others will require login on connecting to the management server
if peer.Status.LoginExpired || !peer.Status.Connected {
continue
}
_, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration)
if nextExpiry == nil || duration < *nextExpiry {
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true
func (a *Account) GetPeersWithExpiration() []*Peer {
peers := make([]*Peer, 0)
for _, peer := range a.Peers {
if peer.LoginExpirationEnabled {
peers = append(peers, peer)
}
}
return peers
}
// GetPeers returns a list of all Account peers // GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*Peer { func (a *Account) GetPeers() []*Peer {
var peers []*Peer var peers []*Peer
@ -550,13 +603,14 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain, dnsDomain: dnsDomain,
eventStore: eventStore, eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
} }
allAccounts := store.GetAllAccounts() allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1 // enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
if am.singleAccountMode { if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) { if !isDomainValid(singleAccountModeDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for single accound mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
} }
am.singleAccountModeDomain = singleAccountModeDomain am.singleAccountModeDomain = singleAccountModeDomain
log.Infof("single account mode enabled, accounts number %d", len(allAccounts)) log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
@ -640,12 +694,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
event := activity.AccountPeerLoginExpirationEnabled event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled { if !newSettings.PeerLoginExpirationEnabled {
event = activity.AccountPeerLoginExpirationDisabled event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel([]string{accountID})
} else {
am.checkAndSchedulePeerLoginExpiration(account)
} }
am.storeEvent(userID, accountID, accountID, event, nil) am.storeEvent(userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(account)
} }
updatedAccount := account.UpdateSettings(newSettings) updatedAccount := account.UpdateSettings(newSettings)
@ -658,6 +716,54 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return updatedAccount, nil return updatedAccount, nil
} }
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", account.Id)
return account.GetNextPeerExpiration()
}
var peerIDs []string
for _, peer := range account.GetExpiredPeers() {
if peer.Status.LoginExpired {
continue
}
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return account.GetNextPeerExpiration()
}
}
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(peerIDs)
err := am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextPeerExpiration()
}
}
return account.GetNextPeerExpiration()
}
}
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) {
am.peerLoginExpiry.Cancel([]string{account.Id})
if nextRun, ok := account.GetNextPeerExpiration(); ok {
go am.peerLoginExpiry.Schedule(nextRun, account.Id, am.peerLoginExpirationJob(account.Id))
}
}
// newAccount creates a new Account with a generated ID and generated default setup keys. // newAccount creates a new Account with a generated ID and generated default setup keys.
// If ID is already in use (due to collision) we try one more time before returning error // If ID is already in use (due to collision) we try one more time before returning error
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) { func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {

View File

@ -1294,6 +1294,147 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour) assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
} }
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, err := manager.AddPeer("", userID, &Peer{
Key: key.PublicKey().String(),
Meta: PeerSystemMeta{},
Name: "test-peer",
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
err = manager.MarkPeerConnected(key.PublicKey().String(), true)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true})
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
// disable expiration first
update := peer.Copy()
update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine
update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second)
if failed {
t.Fatal("timeout while waiting for test to finish")
}
}
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, err = manager.AddPeer("", userID, &Peer{
Key: key.PublicKey().String(),
Meta: PeerSystemMeta{},
Name: "test-peer",
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true})
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(key.PublicKey().String(), true)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
if failed {
t.Fatal("timeout while waiting for test to finish")
}
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, err = manager.AddPeer("", userID, &Peer{
Key: key.PublicKey().String(),
Meta: PeerSystemMeta{},
Name: "test-peer",
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
err = manager.MarkPeerConnected(key.PublicKey().String(), true)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true})
require.NoError(t, err, "expecting to update account settings successfully but got error")
failed := waitTimeout(wg, time.Second)
if failed {
t.Fatal("timeout while waiting for test to finish")
}
wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false})
require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second)
if failed {
t.Fatal("timeout while waiting for test to finish")
}
}
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
@ -1326,6 +1467,286 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
} }
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string
peers map[string]*Peer
expectedPeers map[string]struct{}
}
testCases := []test{
{
name: "Peers with login expiration disabled, no expired peers",
peers: map[string]*Peer{
"peer-1": {
LoginExpirationEnabled: false,
},
"peer-2": {
LoginExpirationEnabled: false,
},
},
expectedPeers: map[string]struct{}{},
},
{
name: "Two peers expired",
peers: map[string]*Peer{
"peer-1": {
ID: "peer-1",
LoginExpirationEnabled: true,
Status: &PeerStatus{
LastSeen: time.Now(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().Add(-30 * time.Minute),
},
"peer-2": {
ID: "peer-2",
LoginExpirationEnabled: true,
Status: &PeerStatus{
LastSeen: time.Now(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().Add(-2 * time.Hour),
},
"peer-3": {
ID: "peer-3",
LoginExpirationEnabled: true,
Status: &PeerStatus{
LastSeen: time.Now(),
Connected: true,
LoginExpired: false,
},
LastLogin: time.Now().Add(-1 * time.Hour),
},
},
expectedPeers: map[string]struct{}{
"peer-2": {},
"peer-3": {},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
Settings: &Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
}
expiredPeers := account.GetExpiredPeers()
assert.Len(t, expiredPeers, len(testCase.expectedPeers))
for _, peer := range expiredPeers {
if _, ok := testCase.expectedPeers[peer.ID]; !ok {
t.Fatalf("expected to have peer %s expired", peer.ID)
}
}
})
}
}
func TestAccount_GetPeersWithExpiration(t *testing.T) {
type test struct {
name string
peers map[string]*Peer
expectedPeers map[string]struct{}
}
testCases := []test{
{
name: "No account peers, no peers with expiration",
peers: map[string]*Peer{},
expectedPeers: map[string]struct{}{},
},
{
name: "Peers with login expiration disabled, no peers with expiration",
peers: map[string]*Peer{
"peer-1": {
LoginExpirationEnabled: false,
},
"peer-2": {
LoginExpirationEnabled: false,
},
},
expectedPeers: map[string]struct{}{},
},
{
name: "Peers with login expiration enabled, return peers with expiration",
peers: map[string]*Peer{
"peer-1": {
ID: "peer-1",
LoginExpirationEnabled: true,
},
"peer-2": {
LoginExpirationEnabled: false,
},
},
expectedPeers: map[string]struct{}{
"peer-1": {},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
}
actual := account.GetPeersWithExpiration()
assert.Len(t, actual, len(testCase.expectedPeers))
if len(testCase.expectedPeers) > 0 {
for k := range testCase.expectedPeers {
contains := false
for _, peer := range actual {
if k == peer.ID {
contains = true
}
}
assert.True(t, contains)
}
}
})
}
}
func TestAccount_GetNextPeerExpiration(t *testing.T) {
type test struct {
name string
peers map[string]*Peer
expiration time.Duration
expirationEnabled bool
expectedNextRun bool
expectedNextExpiration time.Duration
}
expectedNextExpiration := time.Minute
testCases := []test{
{
name: "No peers, no expiration",
peers: map[string]*Peer{},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "No connected peers, no expiration",
peers: map[string]*Peer{
"peer-1": {
Status: &PeerStatus{
Connected: false,
},
LoginExpirationEnabled: true,
},
"peer-2": {
Status: &PeerStatus{
Connected: true,
},
LoginExpirationEnabled: false,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "Connected peers with disabled expiration, no expiration",
peers: map[string]*Peer{
"peer-1": {
Status: &PeerStatus{
Connected: true,
},
LoginExpirationEnabled: false,
},
"peer-2": {
Status: &PeerStatus{
Connected: true,
},
LoginExpirationEnabled: false,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "Expired peers, no expiration",
peers: map[string]*Peer{
"peer-1": {
Status: &PeerStatus{
Connected: true,
LoginExpired: true,
},
LoginExpirationEnabled: true,
},
"peer-2": {
Status: &PeerStatus{
Connected: true,
LoginExpired: true,
},
LoginExpirationEnabled: true,
},
},
expiration: time.Second,
expirationEnabled: false,
expectedNextRun: false,
expectedNextExpiration: time.Duration(0),
},
{
name: "To be expired peer, return expiration",
peers: map[string]*Peer{
"peer-1": {
Status: &PeerStatus{
Connected: true,
LoginExpired: false,
},
LoginExpirationEnabled: true,
LastLogin: time.Now(),
},
"peer-2": {
Status: &PeerStatus{
Connected: true,
LoginExpired: true,
},
LoginExpirationEnabled: true,
},
},
expiration: time.Minute,
expirationEnabled: false,
expectedNextRun: true,
expectedNextExpiration: expectedNextExpiration,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
Peers: testCase.peers,
Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled},
}
expiration, ok := account.GetNextPeerExpiration()
assert.Equal(t, ok, testCase.expectedNextRun)
if testCase.expectedNextRun {
assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration)
} else {
assert.Equal(t, expiration, testCase.expectedNextExpiration)
}
})
}
}
func createManager(t *testing.T) (*DefaultAccountManager, error) { func createManager(t *testing.T) (*DefaultAccountManager, error) {
store, err := createStore(t) store, err := createStore(t)
if err != nil { if err != nil {
@ -1344,3 +1765,17 @@ func createStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
return false
case <-time.After(timeout):
return true
}
}

View File

@ -136,7 +136,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if err != nil { if err != nil {
return status.Error(codes.Internal, "internal server error") return status.Error(codes.Internal, "internal server error")
} }
expired, left := peer.LoginExpired(account.Settings) expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired
if peer.UserID != "" && (expired || peer.Status.LoginExpired) { if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true) err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true)
if err != nil { if err != nil {
@ -380,7 +381,9 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, "internal server error") return nil, status.Error(codes.Internal, "internal server error")
} }
expired, left := peer.LoginExpired(account.Settings)
expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired
if peer.UserID != "" && (expired || peer.Status.LoginExpired) { if peer.UserID != "" && (expired || peer.Status.LoginExpired) {
// it might be that peer expired but user has logged in already, check token then // it might be that peer expired but user has logged in already, check token then
if loginReq.GetJwtToken() == "" { if loginReq.GetJwtToken() == "" {

View File

@ -93,17 +93,24 @@ func (p *Peer) Copy() *Peer {
} }
} }
// MarkLoginExpired marks peer's status expired or not
func (p *Peer) MarkLoginExpired(expired bool) {
newStatus := p.Status.Copy()
newStatus.LastSeen = time.Now()
newStatus.LoginExpired = expired
p.Status = newStatus
}
// LoginExpired indicates whether the peer's login has expired or not. // LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. // Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled // Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// and if disabled on the Account level, then Peer.LoginExpirationEnabled is ineffective. func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
func (p *Peer) LoginExpired(accountSettings *Settings) (bool, time.Duration) { expiresAt := p.LastLogin.Add(expiresIn)
expiresAt := p.LastLogin.Add(accountSettings.PeerLoginExpiration)
now := time.Now() now := time.Now()
timeLeft := expiresAt.Sub(now) timeLeft := expiresAt.Sub(now)
return accountSettings.PeerLoginExpirationEnabled && p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft return p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft
} }
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain // FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
@ -202,13 +209,10 @@ func (am *DefaultAccountManager) MarkPeerLoginExpired(peerPubKey string, loginEx
return err return err
} }
newStatus := peer.Status.Copy() peer.MarkLoginExpired(loginExpired)
newStatus.LastSeen = time.Now()
newStatus.LoginExpired = loginExpired
peer.Status = newStatus
account.UpdatePeer(peer) account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil { if err != nil {
return err return err
} }
@ -237,7 +241,8 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
return err return err
} }
newStatus := peer.Status.Copy() oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now() newStatus.LastSeen = time.Now()
newStatus.Connected = connected newStatus.Connected = connected
// whenever peer got connected that means that it logged in successfully // whenever peer got connected that means that it logged in successfully
@ -251,6 +256,20 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
if err != nil { if err != nil {
return err return err
} }
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(account)
}
if oldStatus.LoginExpired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
//the expired one. Here we notify them that connection is now allowed again.
err = am.updateAccountPeers(account)
if err != nil {
return err
}
}
return nil return nil
} }
@ -307,6 +326,10 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe
event = activity.PeerLoginExpirationDisabled event = activity.PeerLoginExpirationDisabled
} }
am.storeEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.storeEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(account)
}
} }
account.UpdatePeer(peer) account.UpdatePeer(peer)
@ -529,7 +552,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
SSHEnabled: false, SSHEnabled: false,
SSHKey: peer.SSHKey, SSHKey: peer.SSHKey,
LastLogin: time.Now(), LastLogin: time.Now(),
LoginExpirationEnabled: false, LoginExpirationEnabled: true,
} }
// add peer to 'All' group // add peer to 'All' group
@ -775,6 +798,10 @@ func (a *Account) getPeersByACL(peerID string) []*Peer {
) )
continue continue
} }
expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration)
if expired {
continue
}
// exclude original peer // exclude original peer
if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok { if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok {
peersSet[peer.ID] = struct{}{} peersSet[peer.ID] = struct{}{}

View File

@ -57,7 +57,7 @@ func TestPeer_LoginExpired(t *testing.T) {
LastLogin: c.lastLogin, LastLogin: c.lastLogin,
} }
expired, _ := peer.LoginExpired(c.accountSettings) expired, _ := peer.LoginExpired(c.accountSettings.PeerLoginExpiration)
assert.Equal(t, expired, c.expected) assert.Equal(t, expired, c.expected)
}) })
} }

View File

@ -0,0 +1,114 @@
package server
import (
log "github.com/sirupsen/logrus"
"sync"
"time"
)
// Scheduler is an interface which implementations can schedule and cancel jobs
type Scheduler interface {
Cancel(IDs []string)
Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
}
// MockScheduler is a mock implementation of Scheduler
type MockScheduler struct {
CancelFunc func(IDs []string)
ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
}
// Cancel mocks the Cancel function of the Scheduler interface
func (mock *MockScheduler) Cancel(IDs []string) {
if mock.CancelFunc != nil {
mock.CancelFunc(IDs)
return
}
log.Errorf("MockScheduler doesn't have Cancel function defined ")
}
// Schedule mocks the Schedule function of the Scheduler interface
func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
if mock.ScheduleFunc != nil {
mock.ScheduleFunc(in, ID, job)
return
}
log.Errorf("MockScheduler doesn't have Schedule function defined")
}
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
type DefaultScheduler struct {
// jobs map holds cancellation channels indexed by the job ID
jobs map[string]chan struct{}
mu *sync.Mutex
}
// NewDefaultScheduler creates an instance of a DefaultScheduler
func NewDefaultScheduler() *DefaultScheduler {
return &DefaultScheduler{
jobs: make(map[string]chan struct{}),
mu: &sync.Mutex{},
}
}
func (wm *DefaultScheduler) cancel(ID string) bool {
cancel, ok := wm.jobs[ID]
if ok {
delete(wm.jobs, ID)
select {
case cancel <- struct{}{}:
log.Debugf("cancelled scheduled job %s", ID)
default:
log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID)
return false
}
}
return ok
}
// Cancel cancels the scheduled job by ID if present.
// If job wasn't found the function returns false.
func (wm *DefaultScheduler) Cancel(IDs []string) {
wm.mu.Lock()
defer wm.mu.Unlock()
for _, id := range IDs {
wm.cancel(id)
}
}
// Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time.
// If job with the provided ID already exists, a new one won't be scheduled.
func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wm.mu.Lock()
defer wm.mu.Unlock()
cancel := make(chan struct{})
if _, ok := wm.jobs[ID]; ok {
log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.",
ID, len(wm.jobs))
return
}
wm.jobs[ID] = cancel
log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs))
go func() {
select {
case <-time.After(in):
log.Debugf("time to do a scheduled job %s", ID)
runIn, reschedule := job()
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
if reschedule {
go wm.Schedule(runIn, ID, job)
}
case <-cancel:
log.Debugf("stopped scheduled job %s ", ID)
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
return
}
}()
}

View File

@ -0,0 +1,94 @@
package server
import (
"fmt"
"github.com/stretchr/testify/assert"
"math/rand"
"sync"
"testing"
"time"
)
func TestScheduler_Performance(t *testing.T) {
scheduler := NewDefaultScheduler()
n := 500
wg := &sync.WaitGroup{}
wg.Add(n)
maxMs := 500
minMs := 50
for i := 0; i < n; i++ {
millis := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond
go scheduler.Schedule(millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) {
time.Sleep(millis)
wg.Done()
return 0, false
})
}
failed := waitTimeout(wg, 3*time.Second)
if failed {
t.Fatal("timed out while waiting for test to finish")
return
}
assert.Len(t, scheduler.jobs, 0)
}
func TestScheduler_Cancel(t *testing.T) {
jobID1 := "test-scheduler-job-1"
jobID2 := "test-scheduler-job-2"
scheduler := NewDefaultScheduler()
scheduler.Schedule(2*time.Second, jobID1, func() (nextRunIn time.Duration, reschedule bool) {
return 0, false
})
scheduler.Schedule(2*time.Second, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
return 0, false
})
assert.Len(t, scheduler.jobs, 2)
scheduler.Cancel([]string{jobID1})
assert.Len(t, scheduler.jobs, 1)
assert.NotNil(t, scheduler.jobs[jobID2])
}
func TestScheduler_Schedule(t *testing.T) {
jobID := "test-scheduler-job-1"
scheduler := NewDefaultScheduler()
wg := &sync.WaitGroup{}
wg.Add(1)
// job without reschedule should be triggered once
job := func() (nextRunIn time.Duration, reschedule bool) {
wg.Done()
return 0, false
}
scheduler.Schedule(300*time.Millisecond, jobID, job)
failed := waitTimeout(wg, time.Second)
if failed {
t.Fatal("timed out while waiting for test to finish")
return
}
// job with reschedule should be triggered at least twice
wg = &sync.WaitGroup{}
mx := &sync.Mutex{}
scheduledTimes := 0
wg.Add(2)
job = func() (nextRunIn time.Duration, reschedule bool) {
mx.Lock()
defer mx.Unlock()
// ensure we repeat only twice
if scheduledTimes < 2 {
wg.Done()
scheduledTimes++
return 300 * time.Millisecond, true
}
return 0, false
}
scheduler.Schedule(300*time.Millisecond, jobID, job)
failed = waitTimeout(wg, time.Second)
if failed {
t.Fatal("timed out while waiting for test to finish")
return
}
scheduler.cancel(jobID)
}

View File

@ -115,6 +115,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
Turns: turns, Turns: turns,
}, },
} }
log.Debugf("sending new TURN credentials to peer %s", peerID)
err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
if err != nil { if err != nil {
log.Errorf("error while sending TURN update to peer %s %v", peerID, err) log.Errorf("error while sending TURN update to peer %s %v", peerID, err)

View File

@ -60,10 +60,7 @@ func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage {
return channel return channel
} }
// CloseChannel closes updates channel of a given peer func (p *PeersUpdateManager) closeChannel(peerID string) {
func (p *PeersUpdateManager) CloseChannel(peerID string) {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
if channel, ok := p.peerChannels[peerID]; ok { if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID) delete(p.peerChannels, peerID)
close(channel) close(channel)
@ -72,6 +69,22 @@ func (p *PeersUpdateManager) CloseChannel(peerID string) {
log.Debugf("closed updates channel of a peer %s", peerID) log.Debugf("closed updates channel of a peer %s", peerID)
} }
// CloseChannels closes updates channel for each given peer
func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
for _, id := range peerIDs {
p.closeChannel(id)
}
}
// CloseChannel closes updates channel of a given peer
func (p *PeersUpdateManager) CloseChannel(peerID string) {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
p.closeChannel(peerID)
}
// GetAllConnectedPeers returns a copy of the connected peers map // GetAllConnectedPeers returns a copy of the connected peers map
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} { func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
p.channelsMux.Lock() p.channelsMux.Lock()