Remove redundant accounts All group check on startup

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-03 18:49:16 +03:00
parent bfeb7f0875
commit 4ad00e784c
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
4 changed files with 48 additions and 56 deletions

View File

@ -1052,39 +1052,21 @@ func BuildManager(
metrics: metrics,
requestBuffer: NewAccountRequestBuffer(ctx, store),
}
allAccounts := store.GetAllAccounts(ctx)
allAccountIDs, err := store.GetAllAccountIDs(ctx, LockingStrengthShare)
if err != nil {
return nil, err
}
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccountIDs) <= 1
if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
}
am.singleAccountModeDomain = singleAccountModeDomain
log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts))
log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccountIDs))
} else {
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts))
}
// if account doesn't have a default group
// we create 'all' group and add all peers into it
// also we create default rule with source as destination
for _, account := range allAccounts {
shouldSave := false
_, err := account.GetGroupAll()
if err != nil {
if err := addAllGroup(account); err != nil {
return nil, err
}
shouldSave = true
}
if shouldSave {
err = store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
}
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccountIDs))
}
goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
@ -1290,19 +1272,18 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
for i := 0; i < 2; i++ {
accountId := xid.New().String()
_, err := am.Store.GetAccount(ctx, accountId)
statusErr, _ := status.FromError(err)
switch {
case err == nil:
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
continue
case statusErr.Type() == status.NotFound:
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountId)
if err != nil {
log.WithContext(ctx).Errorf("error while checking account existence: %v", err)
return nil, err
}
if !exists {
newAccount := newAccountWithId(ctx, accountId, userID, domain)
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
default:
return nil, err
}
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
}
return nil, status.Errorf(status.Internal, "error while creating new account")
@ -1321,16 +1302,16 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
// update their AppMetadata with the AccountID.
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
for _, user := range unsetData {
accountID, err := am.Store.GetAccountByUser(ctx, user.ID)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, user.ID)
if err == nil {
data := userData[accountID.Id]
data := userData[userAccountID]
if data == nil {
data = make([]*idp.UserData, 0, 1)
}
user.AppMetadata.WTAccountID = accountID.Id
user.AppMetadata.WTAccountID = userAccountID
userData[accountID.Id] = append(data, user)
userData[userAccountID] = append(data, user)
}
}
}
@ -1416,7 +1397,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided")
}
accountID, err := am.Store.GetAccountIDByUserID(userID)
accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@ -1696,9 +1677,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil
}
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
@ -1716,7 +1694,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
}
newDomain := accountDomain
newCategoty := domainCategory
newCategory := domainCategory
lowerDomain := strings.ToLower(claims.Domain)
if accountDomain != lowerDomain && user.HasAdminPower() {
@ -1724,10 +1702,10 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
}
if accountDomain == lowerDomain {
newCategoty = claims.DomainCategory
newCategory = claims.DomainCategory
}
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategory, primaryDomain)
}
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
@ -2163,7 +2141,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err
}
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
@ -2209,7 +2187,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err

View File

@ -411,7 +411,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false
if len(userID) > 0 {
addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(userID)
accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
} else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
}

View File

@ -324,7 +324,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a
return nil
}
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error {
accountCopy := Account{
Domain: domain,
DomainCategory: category,
@ -332,7 +332,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
}
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
result := s.db.WithContext(ctx).Model(&Account{}).
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select(fieldsToUpdate).
Where(idQueryCondition, accountID).
Updates(&accountCopy)
@ -563,6 +563,18 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
return all
}
func (s *SqlStore) GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error) {
var accountIDs []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&Account{}).Pluck("id", &accountIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get account IDs from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account IDs from store")
}
return accountIDs, nil
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
start := time.Now()
defer func() {
@ -704,14 +716,15 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
return accountID, nil
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewAccountNotFoundError()
}
log.WithContext(ctx).Errorf("failed to get accountID from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error)
}

View File

@ -47,8 +47,9 @@ type Store interface {
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error)
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
@ -62,7 +63,7 @@ type Store interface {
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)