Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-28 12:18:02 +03:00
parent 21561a2b07
commit fde9f2ffda
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
3 changed files with 22 additions and 14 deletions

View File

@ -558,21 +558,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) err = transaction.AddPeerToAllGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err) return fmt.Errorf("failed adding peer to All group: %w", err)
} }
if len(groupsToAdd) > 0 { if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd { for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g) err = transaction.AddPeerToGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID, g)
if err != nil { if err != nil {
return err return err
} }
} }
} }
err = transaction.AddPeerToAccount(ctx, newPeer) err = transaction.AddPeerToAccount(ctx, LockingStrengthUpdate, newPeer)
if err != nil { if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err) return fmt.Errorf("failed to add peer to account: %w", err)
} }

View File

@ -1030,9 +1030,10 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
return nil return nil
} }
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
var group nbgroup.Group var group nbgroup.Group
result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&group, "account_id = ? AND name = ?", accountID, "All")
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account") return status.Errorf(status.NotFound, "group 'All' not found for account")
@ -1048,16 +1049,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID) group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil { if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All': %s", err) return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
} }
return nil return nil
} }
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
var group nbgroup.Group var group nbgroup.Group
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
First(&group)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID) return status.NewGroupNotFoundError(groupID)
@ -1074,7 +1076,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId) group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil { if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err) return status.Errorf(status.Internal, "issue updating group: %s", err)
} }
@ -1096,6 +1098,12 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
// GetUserPeers retrieves peers for a user. // GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
if userID == "" {
return peers, nil
}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&peers, "account_id = ? AND user_id = ?", accountID, userID) Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@ -1106,8 +1114,8 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
return peers, nil return peers, nil
} }
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
if err := s.db.Create(peer).Error; err != nil { if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err) return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
} }

View File

@ -95,9 +95,9 @@ type Store interface {
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)