From 0bdcb41e20223563bbff0e8e5a5becfa2a781a96 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 23 Oct 2024 19:03:48 +0300 Subject: [PATCH] Refactor peer expiry, inactivity, location and status update to remove get account Signed-off-by: bcmmbaga --- management/server/account.go | 59 ++++++------------- management/server/account_test.go | 23 +++----- management/server/mock_server/account_mock.go | 4 +- management/server/peer.go | 31 ++++++---- management/server/sql_store.go | 8 +-- management/server/sql_store_test.go | 16 ++--- management/server/store.go | 4 +- management/server/user.go | 16 +++-- 8 files changed, 72 insertions(+), 89 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 5759b1730..c08665c2a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -95,7 +95,7 @@ type AccountManager interface { GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) @@ -1230,29 +1230,24 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) - return account.GetNextPeerExpiration() + return 0, false } - expiredPeers := account.GetExpiredPeers() var peerIDs []string for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) } - log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextPeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) + return 0, false } - return account.GetNextPeerExpiration() + return am.getNextPeerExpiration(ctx, accountID) } } @@ -1266,29 +1261,25 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { - log.Errorf("failed getting account %s expiring peers", account.Id) - return account.GetNextInactivePeerExpiration() + log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) + return 0, false } - expiredPeers := account.GetInactivePeers() var peerIDs []string - for _, peer := range expiredPeers { + for _, peer := range inactivePeers { peerIDs = append(peerIDs, peer.ID) } - log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextInactivePeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", accountID) + return 0, false } - return account.GetNextInactivePeerExpiration() + return am.getNextInactivePeerExpiration(ctx, accountID) } } @@ -2317,7 +2308,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID return nil, nil, nil, err } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -2326,23 +2317,11 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { - accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) - defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) + err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - return nil - } func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { diff --git a/management/server/account_test.go b/management/server/account_test.go index b656320a9..565c3129b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1695,10 +1695,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ @@ -1721,11 +1718,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { // disable expiration first update := peer.Copy() update.LoginExpirationEnabled = false - _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), accountID, userID, update) require.NoError(t, err, "unable to update peer") // enabling expiration should trigger the routine update.LoginExpirationEnabled = true - _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), accountID, userID, update) require.NoError(t, err, "unable to update peer") failed := waitTimeout(wg, time.Second) @@ -1769,11 +1766,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1801,10 +1795,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} @@ -1818,7 +1809,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1831,7 +1822,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a996db298..18f6ff16c 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -115,7 +115,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncAndMarkPeer is not implemented") } func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { @@ -214,7 +214,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } diff --git a/management/server/peer.go b/management/server/peer.go index a4fb6c15b..eb60d02cb 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -104,37 +104,46 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { - peer, err := account.FindPeerByPubKey(peerPubKey) +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey) if err != nil { return err } - expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID) if err != nil { return err } if peer.AddedWithSSOLogin() { - if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account.Id) + if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account.Id) + if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } am.updateAccountPeers(ctx, account) } return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -154,16 +163,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context peer.Location.CountryCode = location.Country.ISOCode peer.Location.CityName = location.City.Names.En peer.Location.GeoNameID = location.City.GeonameID - err = am.Store.SavePeerLocation(account.Id, peer) + err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer) if err != nil { log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } - account.UpdatePeer(peer) - - err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus) if err != nil { return false, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index bd2c42373..fe2e4ee76 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -346,7 +346,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID return nil } -func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { +func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -354,7 +354,7 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe "peer_status_last_seen", "peer_status_connected", "peer_status_login_expired", "peer_status_required_approval", } - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Select(fieldsToUpdate). Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) @@ -369,14 +369,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe return nil } -func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { +func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, // updating the struct ensures the correct data format is inserted into the database. peerCopy.Location = peerWithLocation.Location - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Updates(peerCopy) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index a30c9771a..0a3e1aaf9 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -445,7 +445,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -465,7 +465,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -476,7 +476,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { newStatus.Connected = true - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -511,7 +511,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { Meta: nbpeer.PeerSystemMeta{}, } // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) account.Peers[peer.ID] = peer @@ -523,7 +523,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { peer.Location.CityName = "Berlin" peer.Location.GeoNameID = 2950159 - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID]) assert.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -533,7 +533,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { assert.Equal(t, peer.Location, actual) peer.ID = "non-existing-peer" - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -916,7 +916,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) // save new status of existing peer @@ -933,7 +933,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) diff --git a/management/server/store.go b/management/server/store.go index 6ab83c458..1f2ebc81e 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -102,8 +102,8 @@ type Store interface { GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error + SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error diff --git a/management/server/user.go b/management/server/user.go index cd5e7c334..44b371b49 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -768,7 +768,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if len(expiredPeers) > 0 { - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } @@ -1049,21 +1049,22 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { if peer.Status.LoginExpired { continue } + peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { + + if err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { return err } am.StoreEvent( ctx, - peer.UserID, peer.ID, account.Id, + peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), ) } @@ -1071,6 +1072,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } am.updateAccountPeers(ctx, account) } return nil