Check if new account ID is already being used (#364)

This commit is contained in:
Maycon Santos 2022-06-20 18:20:43 +02:00 committed by GitHub
parent 35c7cae267
commit f9f2d7c7ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 15 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
.idea .idea
.run
*.iml *.iml
dist/ dist/
bin/ bin/

View File

@ -100,12 +100,6 @@ type UserInfo struct {
Role string `json:"role"` Role string `json:"role"`
} }
// NewAccount creates a new Account with a generated ID and generated default setup keys
func NewAccount(userId, domain string) *Account {
accountId := xid.New().String()
return newAccountWithId(accountId, userId, domain)
}
func (a *Account) Copy() *Account { func (a *Account) Copy() *Account {
peers := map[string]*Peer{} peers := map[string]*Peer{}
for id, peer := range a.Peers { for id, peer := range a.Peers {
@ -198,6 +192,27 @@ func BuildManager(
} }
// newAccount creates a new Account with a generated ID and generated default setup keys.
// If ID is already in use (due to collision) we try one more time before returning error
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {
for i := 0; i < 2; i++ {
accountId := xid.New().String()
_, err := am.Store.GetAccount(accountId)
statusErr, _ := status.FromError(err)
if err == nil {
log.Warnf("an account with ID already exists, retrying...")
continue
} else if statusErr.Code() == codes.NotFound {
return newAccountWithId(accountId, userID, domain), nil
} else {
return nil, err
}
}
return nil, status.Errorf(codes.Internal, "error while creating new account")
}
func (am *DefaultAccountManager) warmupIDPCache() error { func (am *DefaultAccountManager) warmupIDPCache() error {
userData, err := am.idpManager.GetAllAccounts() userData, err := am.idpManager.GetAllAccounts()
if err != nil { if err != nil {
@ -368,7 +383,7 @@ func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo {
} }
} }
func (am *DefaultAccountManager) loadFromCache(ctx context.Context, accountID interface{}) (interface{}, error) { func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID)) return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
} }
@ -458,8 +473,17 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
primaryDomain bool, primaryDomain bool,
) error { ) error {
account.IsDomainPrimaryAccount = primaryDomain account.IsDomainPrimaryAccount = primaryDomain
account.Domain = strings.ToLower(claims.Domain)
account.DomainCategory = claims.DomainCategory lowerDomain := strings.ToLower(claims.Domain)
userObj := account.Users[claims.UserId]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
}
// prevent updating category for different domain until admin logs in
if account.Domain == lowerDomain {
account.DomainCategory = claims.DomainCategory
}
err := am.Store.SaveAccount(account) err := am.Store.SaveAccount(account)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "failed saving updated account") return status.Errorf(codes.Internal, "failed saving updated account")
@ -523,7 +547,10 @@ func (am *DefaultAccountManager) handleNewUserAccount(
return nil, status.Errorf(codes.Internal, "failed saving updated account") return nil, status.Errorf(codes.Internal, "failed saving updated account")
} }
} else { } else {
account = NewAccount(claims.UserId, lowerDomain) account, err = am.newAccount(claims.UserId, lowerDomain)
if err != nil {
return nil, err
}
err = am.updateAccountDomainAttributes(account, claims, true) err = am.updateAccountDomainAttributes(account, claims, true)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -96,7 +96,8 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io" domain := "netbird.io"
userId := "account_creator" userId := "account_creator"
account := NewAccount(userId, domain) accountID := "account_id"
account := newAccountWithId(accountID, userId, domain)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
} }

View File

@ -33,7 +33,7 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) { func TestSaveAccount(t *testing.T) {
store := newStore(t) store := newStore(t)
account := NewAccount("testuser", "") account := newAccountWithId("account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &Peer{
@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) {
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
store := newStore(t) store := newStore(t)
account := NewAccount("testuser", "") account := newAccountWithId("account_id", "testuser", "")
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",

View File

@ -59,7 +59,10 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
account, err := am.Store.GetUserAccount(userId) account, err := am.Store.GetUserAccount(userId)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account = NewAccount(userId, lowerDomain) account, err = am.newAccount(userId, lowerDomain)
if err != nil {
return nil, err
}
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed creating account") return nil, status.Errorf(codes.Internal, "failed creating account")
@ -70,7 +73,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
} }
} }
if account.Domain != lowerDomain { userObj := account.Users[userId]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain account.Domain = lowerDomain
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {