mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-22 05:49:12 +01:00
Remove redundant accounts All group check on startup
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
bfeb7f0875
commit
4ad00e784c
@ -1052,39 +1052,21 @@ func BuildManager(
|
||||
metrics: metrics,
|
||||
requestBuffer: NewAccountRequestBuffer(ctx, store),
|
||||
}
|
||||
allAccounts := store.GetAllAccounts(ctx)
|
||||
allAccountIDs, err := store.GetAllAccountIDs(ctx, LockingStrengthShare)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
|
||||
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccountIDs) <= 1
|
||||
if am.singleAccountMode {
|
||||
if !isDomainValid(singleAccountModeDomain) {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
|
||||
}
|
||||
am.singleAccountModeDomain = singleAccountModeDomain
|
||||
log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts))
|
||||
log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccountIDs))
|
||||
} else {
|
||||
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts))
|
||||
}
|
||||
|
||||
// if account doesn't have a default group
|
||||
// we create 'all' group and add all peers into it
|
||||
// also we create default rule with source as destination
|
||||
for _, account := range allAccounts {
|
||||
shouldSave := false
|
||||
|
||||
_, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := addAllGroup(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shouldSave = true
|
||||
}
|
||||
|
||||
if shouldSave {
|
||||
err = store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccountIDs))
|
||||
}
|
||||
|
||||
goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
|
||||
@ -1290,19 +1272,18 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
||||
for i := 0; i < 2; i++ {
|
||||
accountId := xid.New().String()
|
||||
|
||||
_, err := am.Store.GetAccount(ctx, accountId)
|
||||
statusErr, _ := status.FromError(err)
|
||||
switch {
|
||||
case err == nil:
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while checking account existence: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.Internal, "error while creating new account")
|
||||
@ -1321,16 +1302,16 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
|
||||
// update their AppMetadata with the AccountID.
|
||||
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
|
||||
for _, user := range unsetData {
|
||||
accountID, err := am.Store.GetAccountByUser(ctx, user.ID)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, user.ID)
|
||||
if err == nil {
|
||||
data := userData[accountID.Id]
|
||||
data := userData[userAccountID]
|
||||
if data == nil {
|
||||
data = make([]*idp.UserData, 0, 1)
|
||||
}
|
||||
|
||||
user.AppMetadata.WTAccountID = accountID.Id
|
||||
user.AppMetadata.WTAccountID = userAccountID
|
||||
|
||||
userData[accountID.Id] = append(data, user)
|
||||
userData[userAccountID] = append(data, user)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1416,7 +1397,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
||||
}
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByUserID(userID)
|
||||
accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
@ -1696,9 +1677,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
@ -1716,7 +1694,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
}
|
||||
|
||||
newDomain := accountDomain
|
||||
newCategoty := domainCategory
|
||||
newCategory := domainCategory
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||
@ -1724,10 +1702,10 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
}
|
||||
|
||||
if accountDomain == lowerDomain {
|
||||
newCategoty = claims.DomainCategory
|
||||
newCategory = claims.DomainCategory
|
||||
}
|
||||
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategory, primaryDomain)
|
||||
}
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
@ -2163,7 +2141,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
@ -2209,7 +2187,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
|
@ -411,7 +411,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
addedByUser := false
|
||||
if len(userID) > 0 {
|
||||
addedByUser = true
|
||||
accountID, err = am.Store.GetAccountIDByUserID(userID)
|
||||
accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
|
||||
} else {
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
|
||||
}
|
||||
|
@ -324,7 +324,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
accountCopy := Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
@ -332,7 +332,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
@ -563,6 +563,18 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
|
||||
return all
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error) {
|
||||
var accountIDs []string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Model(&Account{}).Pluck("id", &accountIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account IDs from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get account IDs from store")
|
||||
}
|
||||
|
||||
return accountIDs, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@ -704,14 +716,15 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&User{}).
|
||||
Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.NewAccountNotFoundError()
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get accountID from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
|
@ -47,8 +47,9 @@ type Store interface {
|
||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error)
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByUserID(userID string) (string, error)
|
||||
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
@ -62,7 +63,7 @@ type Store interface {
|
||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
|
Loading…
Reference in New Issue
Block a user