From 207fa059d22b566e2887fbffbd20b6791d73c464 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 19 May 2025 16:42:47 +0200 Subject: [PATCH] [management] make locking strength clause optional (#3844) --- management/server/store/sql_store.go | 379 +++++++++++++++++++++++---- management/server/store/store.go | 1 + 2 files changed, 326 insertions(+), 54 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 878f09c7e..eb194ca9b 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -481,8 +481,13 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) } func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("id"). + result := tx.Model(&types.Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", strings.ToLower(domain), true, types.PrivateCategory, ).First(&accountID) @@ -530,8 +535,13 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri } func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var user types.User - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). Where("personal_access_tokens.id = ?", patID).First(&user) if result.Error != nil { @@ -546,8 +556,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var user types.User - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID) + result := tx.First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -578,8 +593,13 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, } func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var users []*types.User - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) + result := tx.Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -592,8 +612,13 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var user types.User - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) + result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed") @@ -605,8 +630,13 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var groups []*types.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) + result := tx.Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -619,11 +649,16 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr } func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var groups []*types.Group likePattern := `%"ID":"` + resourceID + `"%` - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Where("resources LIKE ?", likePattern). Find(&groups) @@ -664,8 +699,13 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { } func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountMeta types.AccountMeta - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + result := tx.Model(&types.Account{}). First(&accountMeta, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account meta %s from the store: %s", accountID, result.Error) @@ -833,8 +873,13 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.User{}). + result := tx.Model(&types.User{}). Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -847,8 +892,13 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin } func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + result := tx.Model(&nbpeer.Peer{}). Select("account_id").Where(idQueryCondition, peerID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -879,10 +929,15 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) } func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var ipJSONStrings []string // Fetch the IP addresses as JSON strings - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + result := tx.Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). Pluck("ip", &ipJSONStrings) if result.Error != nil { @@ -906,8 +961,13 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var labels []string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + result := tx.Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). Pluck("dns_label", &labels) @@ -923,8 +983,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountNetwork types.AccountNetwork - if err := s.db.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -934,8 +999,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var peer nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, GetKeyQueryCondition(s), peerKey) + result := tx.First(&peer, GetKeyQueryCondition(s), peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -948,8 +1018,13 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountSettings types.AccountSettings - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } @@ -959,8 +1034,13 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS } func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var createdBy string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + result := tx.Model(&types.Account{}). Select("created_by").First(&createdBy, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1156,8 +1236,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var setupKey types.SetupKey - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&setupKey, GetKeyQueryCondition(s), key) if result.Error != nil { @@ -1299,8 +1384,13 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string // GetPeerGroups retrieves all groups assigned to a specific peer in a given account. func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var groups []*types.Group - query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + query := tx. Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) if query.Error != nil { @@ -1332,6 +1422,11 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var peers []*nbpeer.Peer // Exclude peers added via setup keys, as they are not user-specific and have an empty user_id. @@ -1339,7 +1434,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt return peers, nil } - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Find(&peers, "account_id = ? AND user_id = ?", accountID, userID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) @@ -1359,8 +1454,13 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStr // GetPeerByID retrieves a peer by its ID and account ID. func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var peer *nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&peer, accountAndIDQueryCondition, accountID, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1374,8 +1474,13 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength // GetPeersByIDs retrieves peers by their IDs and account ID. func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var peers []*nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) + result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store") @@ -1391,8 +1496,13 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng // GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var peers []*nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). Find(&peers, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1405,8 +1515,13 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng // GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user. func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var peers []*nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). Find(&peers, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1419,8 +1534,13 @@ func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStreng // GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing. func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var allEphemeralPeers, batchPeers []*nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Where("ephemeral = ?", true). FindInBatches(&batchPeers, 1000, func(tx *gorm.DB, batch int) error { allEphemeralPeers = append(allEphemeralPeers, batchPeers...) @@ -1496,8 +1616,13 @@ func (s *SqlStore) GetDB() *gorm.DB { } func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountDNSSettings types.AccountDNSSettings - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + result := tx.Model(&types.Account{}). First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1511,8 +1636,13 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + result := tx.Model(&types.Account{}). Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1526,8 +1656,13 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var account types.Account - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("domain", "domain_category"). + result := tx.Model(&types.Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1541,8 +1676,13 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength // GetGroupByID retrieves a group by ID and account ID. func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var group *types.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) + result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewGroupNotFoundError(groupID) @@ -1556,11 +1696,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt // GetGroupByName retrieves a group by name and account ID. func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var group types.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. - query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + query := tx.Preload(clause.Associations) switch s.storeEngine { case types.PostgresStoreEngine: @@ -1584,8 +1729,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // GetGroupsByIDs retrieves groups by their IDs and account ID. func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var groups []*types.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) + result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") @@ -1639,8 +1789,13 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var policies []*types.Policy - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error) @@ -1652,8 +1807,14 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS // GetPolicyByID retrieves a policy by its ID and account ID. func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var policy *types.Policy - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + + result := tx.Preload(clause.Associations). First(&policy, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1712,8 +1873,13 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var postureChecks []*posture.Checks - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + result := tx.Find(&postureChecks, accountIDCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get posture checks from store") @@ -1724,8 +1890,13 @@ func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength Loc // GetPostureChecksByID retrieves posture checks by their ID and account ID. func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var postureCheck *posture.Checks - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1740,8 +1911,13 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin // GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var postureChecks []*posture.Checks - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) + result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store") @@ -1794,8 +1970,13 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var setupKeys []*types.SetupKey - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Find(&setupKeys, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err) @@ -1807,8 +1988,13 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking // GetSetupKeyByID retrieves a setup key by its ID and account ID. func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var setupKey *types.SetupKey - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1849,8 +2035,13 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var nsGroups []*nbdns.NameServerGroup - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) + result := tx.Find(&nsGroups, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err) return nil, status.Errorf(status.Internal, "failed to get name server groups from store") @@ -1861,8 +2052,13 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength // GetNameServerGroupByID retrieves a name server group by its ID and account ID. func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var nsGroup *nbdns.NameServerGroup - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1902,9 +2098,14 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki // getRecords retrieves records from the database based on the account ID. func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { + tx := db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var record []T - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) + result := tx.Find(&record, accountIDCondition, accountID) if err := result.Error; err != nil { parts := strings.Split(fmt.Sprintf("%T", record), ".") recordType := parts[len(parts)-1] @@ -1917,9 +2118,14 @@ func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID stri // getRecordByID retrieves a record by its ID and account ID from the database. func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { + tx := db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var record T - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&record, accountAndIDQueryCondition, accountID, recordID) if err := result.Error; err != nil { parts := strings.Split(fmt.Sprintf("%T", record), ".") @@ -1950,8 +2156,13 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var networks []*networkTypes.Network - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID) + result := tx.Find(&networks, accountIDCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get networks from store") @@ -1961,8 +2172,13 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS } func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var network *networkTypes.Network - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&network, accountAndIDQueryCondition, accountID, networkID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2002,8 +2218,13 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng } func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var netRouters []*routerTypes.NetworkRouter - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) @@ -2014,8 +2235,13 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo } func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var netRouters []*routerTypes.NetworkRouter - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Find(&netRouters, accountIDCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) @@ -2026,8 +2252,13 @@ func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrengt } func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var netRouter *routerTypes.NetworkRouter - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&netRouter, accountAndIDQueryCondition, accountID, routerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2066,8 +2297,13 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var netResources []*resourceTypes.NetworkResource - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) @@ -2078,8 +2314,13 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength } func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var netResources []*resourceTypes.NetworkResource - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Find(&netResources, accountIDCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) @@ -2090,8 +2331,13 @@ func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStren } func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var netResources *resourceTypes.NetworkResource - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&netResources, accountAndIDQueryCondition, accountID, resourceID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2105,8 +2351,13 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var netResources *resourceTypes.NetworkResource - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2146,8 +2397,13 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki // GetPATByHashedToken returns a PersonalAccessToken by its hashed token. func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var pat types.PersonalAccessToken - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken) + result := tx.First(&pat, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewPATNotFoundError(hashedToken) @@ -2161,8 +2417,13 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking // GetPATByID retrieves a personal access token by its ID and user ID. func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var pat types.PersonalAccessToken - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&pat, "id = ? AND user_id = ?", patID, userID) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2177,8 +2438,13 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, // GetUserPATs retrieves personal access tokens for a user. func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var pats []*types.PersonalAccessToken - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID) + result := tx.Find(&pats, "user_id = ?", userID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err) return nil, status.Errorf(status.Internal, "failed to get user pat's from store") @@ -2236,10 +2502,15 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, } func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + jsonValue := fmt.Sprintf(`"%s"`, ip.String()) var peer nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. First(&peer, "account_id = ? AND ip = ?", accountID, jsonValue) if result.Error != nil { // no logging here diff --git a/management/server/store/store.go b/management/server/store/store.go index d64596269..3d529ceb5 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -44,6 +44,7 @@ const ( LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions. LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows. LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates. + LockingStrengthNone LockingStrength = "NONE" // No locking, allowing all transactions to proceed without restrictions. ) type Store interface {