Merge branch 'setupkey-get-account-refactoring' into groups-get-account-refactoring

# Conflicts:
#	management/server/sql_store.go
This commit is contained in:
bcmmbaga 2024-11-08 19:48:13 +03:00
commit d58cf50127
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547

View File

@ -300,13 +300,11 @@ func (s *SqlStore) GetInstallationID() string {
} }
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
startTime := time.Now()
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy() peerCopy := peer.Copy()
peerCopy.AccountID = accountID peerCopy.AccountID = accountID
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving // check if peer exists before saving
var peerID string var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
@ -327,9 +325,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
}) })
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return err return err
} }
@ -337,8 +332,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
} }
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
startTime := time.Now()
accountCopy := Account{ accountCopy := Account{
Domain: domain, Domain: domain,
DomainCategory: category, DomainCategory: category,
@ -346,14 +339,11 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
} }
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
result := s.db.WithContext(ctx).Model(&Account{}). result := s.db.Model(&Account{}).
Select(fieldsToUpdate). Select(fieldsToUpdate).
Where(idQueryCondition, accountID). Where(idQueryCondition, accountID).
Updates(&accountCopy) Updates(&accountCopy)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error return result.Error
} }
@ -365,8 +355,6 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
} }
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
startTime := time.Now()
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus peerCopy.Status = &peerStatus
@ -379,9 +367,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
Where(accountAndIDQueryCondition, accountID, peerID). Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy) Updates(&peerCopy)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error return result.Error
} }
@ -393,8 +378,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
} }
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
startTime := time.Now()
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization, // Since the location field has been migrated to JSON serialization,
@ -406,9 +389,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
Updates(peerCopy) Updates(peerCopy)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error return result.Error
} }
@ -422,8 +402,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
// SaveUsers saves the given list of users to the database. // SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs. // It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
startTime := time.Now()
usersToSave := make([]User, 0, len(users)) usersToSave := make([]User, 0, len(users))
for _, user := range users { for _, user := range users {
user.AccountID = accountID user.AccountID = accountID
@ -437,9 +415,6 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Clauses(clause.OnConflict{UpdateAll: true}). Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error Create(&usersToSave).Error
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save users to store: %v", err) return status.Errorf(status.Internal, "failed to save users to store: %v", err)
} }
@ -448,13 +423,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
// SaveUser saves the given user to the database. // SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
startTime := time.Now() result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
} }
return nil return nil
@ -462,17 +432,12 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
// SaveGroups saves the given list of groups to the database. // SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
startTime := time.Now()
if len(groups) == 0 { if len(groups) == 0 {
return nil return nil
} }
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
} }
return nil return nil
@ -499,10 +464,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
} }
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
startTime := time.Now()
var accountID string var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory, strings.ToLower(domain), true, PrivateCategory,
).First(&accountID) ).First(&accountID)
@ -510,9 +473,6 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
} }
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error) return "", status.NewGetAccountFromStoreError(result.Error)
} }
@ -521,17 +481,12 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
} }
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
startTime := time.Now()
var key SetupKey var key SetupKey
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey) result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewSetupKeyNotFoundError(result.Error) return nil, status.NewSetupKeyNotFoundError(result.Error)
} }
@ -543,17 +498,12 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
} }
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
startTime := time.Now()
var token PersonalAccessToken var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken) result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error) return "", status.NewGetAccountFromStoreError(result.Error)
} }
@ -562,17 +512,12 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
} }
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
startTime := time.Now()
var token PersonalAccessToken var token PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID) result := s.db.First(&token, idQueryCondition, tokenID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error) return nil, status.NewGetAccountFromStoreError(result.Error)
} }
@ -596,18 +541,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
startTime := time.Now()
var user User var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID) Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID) return nil, status.NewUserNotFoundError(userID)
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetUserFromStoreError() return nil, status.NewGetUserFromStoreError()
} }
@ -615,17 +555,12 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
startTime := time.Now()
var users []*User var users []*User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting users from store") return nil, status.Errorf(status.Internal, "issue getting users from store")
} }
@ -634,17 +569,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
} }
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
startTime := time.Now()
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account groups from the store") return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
} }
@ -744,17 +674,12 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
} }
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
startTime := time.Now()
var user User var user User
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID) result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error) return nil, status.NewGetAccountFromStoreError(result.Error)
} }
@ -766,17 +691,12 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
} }
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
startTime := time.Now()
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID) result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error) return nil, status.NewGetAccountFromStoreError(result.Error)
} }
@ -788,17 +708,12 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
} }
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
startTime := time.Now()
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey) result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error) return nil, status.NewGetAccountFromStoreError(result.Error)
} }
@ -810,18 +725,13 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
} }
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
startTime := time.Now()
var peer nbpeer.Peer var peer nbpeer.Peer
var accountID string var accountID string
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) result := s.db.Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewGetAccountFromStoreError(result.Error) return "", status.NewGetAccountFromStoreError(result.Error)
} }
@ -829,17 +739,12 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
} }
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
startTime := time.Now()
var accountID string var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewGetAccountFromStoreError(result.Error) return "", status.NewGetAccountFromStoreError(result.Error)
} }
@ -847,17 +752,12 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
} }
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
startTime := time.Now()
var accountID string var accountID string
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
} }
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewSetupKeyNotFoundError(result.Error) return "", status.NewSetupKeyNotFoundError(result.Error)
} }
@ -869,21 +769,16 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
} }
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
startTime := time.Now()
var ipJSONStrings []string var ipJSONStrings []string
// Fetch the IP addresses as JSON strings // Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID). Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings) Pluck("ip", &ipJSONStrings)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account") return nil, status.Errorf(status.NotFound, "no peers found for the account")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
} }
@ -901,10 +796,8 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
} }
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
startTime := time.Now()
var labels []string var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID). Where("account_id = ?", accountID).
Pluck("dns_label", &labels) Pluck("dns_label", &labels)
@ -912,9 +805,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account") return nil, status.Errorf(status.NotFound, "no peers found for the account")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
} }
@ -923,33 +813,23 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
startTime := time.Now()
var accountNetwork AccountNetwork var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
} }
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
} }
return accountNetwork.Network, nil return accountNetwork.Network, nil
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
startTime := time.Now()
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found") return nil, status.Errorf(status.NotFound, "peer not found")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
} }
@ -957,16 +837,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
startTime := time.Now()
var accountSettings AccountSettings var accountSettings AccountSettings
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found") return nil, status.Errorf(status.NotFound, "settings not found")
} }
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
} }
return accountSettings.Settings, nil return accountSettings.Settings, nil
@ -974,17 +849,12 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
startTime := time.Now()
var user User var user User
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID) result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID) return status.NewUserNotFoundError(userID)
} }
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.NewGetUserFromStoreError() return status.NewGetUserFromStoreError()
} }
user.LastLogin = lastLogin user.LastLogin = lastLogin
@ -993,8 +863,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
} }
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
startTime := time.Now()
definitionJSON, err := json.Marshal(checks) definitionJSON, err := json.Marshal(checks)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1003,9 +871,6 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p
var postureCheck posture.Checks var postureCheck posture.Checks
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, err return nil, err
} }
@ -1115,27 +980,20 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
startTime := time.Now()
var setupKey SetupKey var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, key) First(&setupKey, keyQueryCondition, key)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewSetupKeyNotFoundError(result.Error) return nil, status.NewSetupKeyNotFoundError(result.Error)
} }
return &setupKey, nil return &setupKey, nil
} }
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
startTime := time.Now() result := s.db.Model(&SetupKey{}).
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID). Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{ Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"), "used_times": gorm.Expr("used_times + 1"),
@ -1143,9 +1001,6 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}) })
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error) return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
} }
@ -1157,17 +1012,12 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
} }
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
startTime := time.Now()
var group nbgroup.Group var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group) result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account") return status.Errorf(status.NotFound, "group 'All' not found for account")
} }
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
} }
@ -1180,9 +1030,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID) group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil { if err := s.db.Save(&group).Error; err != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue updating group 'All': %s", err) return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
} }
@ -1190,17 +1037,13 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
} }
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
startTime := time.Now()
var group nbgroup.Group var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID) return status.NewGroupNotFoundError(groupID)
} }
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error) return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
} }
@ -1213,9 +1056,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId) group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil { if err := s.db.Save(&group).Error; err != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue updating group: %s", err) return status.Errorf(status.Internal, "issue updating group: %s", err)
} }
@ -1224,16 +1064,11 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
// GetUserPeers retrieves peers for a user. // GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID)
} }
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
startTime := time.Now() if err := s.db.Create(peer).Error; err != nil {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue adding peer to account: %s", err) return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
} }
@ -1243,7 +1078,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
// GetPeerByID retrieves a peer by its ID and account ID. // GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
var peer *nbpeer.Peer var peer *nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&peer, accountAndIDQueryCondition, accountID, peerID) First(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -1257,21 +1092,16 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
startTime := time.Now()
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
} }
return nil return nil
} }
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin() tx := s.db.Begin()
if tx.Error != nil { if tx.Error != nil {
return tx.Error return tx.Error
} }
@ -1295,18 +1125,13 @@ func (s *SqlStore) GetDB() *gorm.DB {
} }
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
startTime := time.Now()
var accountDNSSettings AccountDNSSettings var accountDNSSettings AccountDNSSettings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountDNSSettings, idQueryCondition, accountID) First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found") return nil, status.Errorf(status.NotFound, "dns settings not found")
} }
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
} }
return &accountDNSSettings.DNSSettings, nil return &accountDNSSettings.DNSSettings, nil
@ -1314,18 +1139,13 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID. // AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
startTime := time.Now()
var accountID string var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("id").First(&accountID, idQueryCondition, id) Select("id").First(&accountID, idQueryCondition, id)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return false, nil return false, nil
} }
if errors.Is(result.Error, context.Canceled) {
return false, status.NewStoreContextCanceledError(time.Since(startTime))
}
return false, result.Error return false, result.Error
} }
@ -1334,18 +1154,13 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
startTime := time.Now()
var account Account var account Account
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account) Where(idQueryCondition, accountID).First(&account)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found") return "", "", status.Errorf(status.NotFound, "account not found")
} }
if errors.Is(result.Error, context.Canceled) {
return "", "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
} }
@ -1431,32 +1246,32 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
} }
// GetPolicyByID retrieves a policy by its ID and account ID. // GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
} }
// GetAccountPostureChecks retrieves posture checks for an account. // GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[*posture.Checks](s.db, lockStrength, accountID)
} }
// GetPostureChecksByID retrieves posture checks by their ID and account ID. // GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
} }
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[*route.Route](s.db, lockStrength, accountID)
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID)
} }
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
@ -1518,12 +1333,12 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
} }
// GetNameServerGroupByID retrieves a name server group by its ID and account ID. // GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
} }
// getRecords retrieves records from the database based on the account ID. // getRecords retrieves records from the database based on the account ID.