[management] Add transaction to addPeer (#2469)

This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer.
This commit is contained in:
pascal-fischer
2024-09-16 15:47:03 +02:00
committed by GitHub
parent 730dd1733e
commit 6c50b0c84b
13 changed files with 1095 additions and 216 deletions

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
@ -33,6 +34,7 @@ import (
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
)
@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
return nil, status.NewSetupKeyNotFoundError()
}
if key.AccountID == "" {
@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
return nil, status.NewUserNotFoundError(userID)
}
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var key SetupKey
var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting setup key from store")
return "", status.NewSetupKeyNotFoundError()
}
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return accountID, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
return status.NewUserNotFoundError(userID)
}
return status.Errorf(status.Internal, "issue getting user from store")
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
return s.db.Save(user).Error
return s.db.Save(&user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}