mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-31 18:39:31 +01:00
Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
21561a2b07
commit
fde9f2ffda
@ -558,21 +558,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||
err = transaction.AddPeerToAllGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
|
||||
err = transaction.AddPeerToGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAccount(ctx, newPeer)
|
||||
err = transaction.AddPeerToAccount(ctx, LockingStrengthUpdate, newPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
@ -1030,9 +1030,10 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
|
||||
var group nbgroup.Group
|
||||
result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&group, "account_id = ? AND name = ?", accountID, "All")
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
@ -1048,16 +1049,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
|
||||
var group nbgroup.Group
|
||||
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
|
||||
First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
@ -1074,7 +1076,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
|
||||
if err := s.db.Save(&group).Error; err != nil {
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
||||
}
|
||||
|
||||
@ -1096,6 +1098,12 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
|
||||
// GetUserPeers retrieves peers for a user.
|
||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
|
||||
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
|
||||
if userID == "" {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
@ -1106,8 +1114,8 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
if err := s.db.Create(peer).Error; err != nil {
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
|
@ -95,9 +95,9 @@ type Store interface {
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
|
Loading…
Reference in New Issue
Block a user