Add lock for peer store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-14 19:24:51 +03:00
parent 8420a52563
commit f5e7449d01
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
2 changed files with 26 additions and 12 deletions

View File

@ -300,12 +300,12 @@ func (s *SqlStore) GetInstallationID() string {
return installation.InstallationIDValue return installation.InstallationIDValue
} }
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy() peerCopy := peer.Copy()
peerCopy.AccountID = accountID peerCopy.AccountID = accountID
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving // check if peer exists before saving
var peerID string var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
@ -355,7 +355,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
return nil return nil
} }
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus peerCopy.Status = &peerStatus
@ -363,7 +363,7 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
"peer_status_last_seen", "peer_status_connected", "peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval", "peer_status_login_expired", "peer_status_required_approval",
} }
result := s.db.Model(&nbpeer.Peer{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Select(fieldsToUpdate). Select(fieldsToUpdate).
Where(accountAndIDQueryCondition, accountID, peerID). Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy) Updates(&peerCopy)
@ -378,14 +378,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
return nil return nil
} }
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization, // Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database. // updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy) Updates(peerCopy)
@ -740,9 +740,10 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
return accountID, nil return accountID, nil
} }
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
var accountID string var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -1066,6 +1067,18 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil return nil
} }
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get peers from store")
}
return peers, nil
}
// GetUserPeers retrieves peers for a user. // GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer var peers []*nbpeer.Peer

View File

@ -48,7 +48,7 @@ type Store interface {
GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
@ -99,15 +99,16 @@ type Store interface {
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)