mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
Check if new account ID is already being used (#364)
This commit is contained in:
parent
35c7cae267
commit
f9f2d7c7ef
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
.idea
|
||||
.run
|
||||
*.iml
|
||||
dist/
|
||||
bin/
|
||||
|
@ -100,12 +100,6 @@ type UserInfo struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// NewAccount creates a new Account with a generated ID and generated default setup keys
|
||||
func NewAccount(userId, domain string) *Account {
|
||||
accountId := xid.New().String()
|
||||
return newAccountWithId(accountId, userId, domain)
|
||||
}
|
||||
|
||||
func (a *Account) Copy() *Account {
|
||||
peers := map[string]*Peer{}
|
||||
for id, peer := range a.Peers {
|
||||
@ -198,6 +192,27 @@ func BuildManager(
|
||||
|
||||
}
|
||||
|
||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||
// If ID is already in use (due to collision) we try one more time before returning error
|
||||
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {
|
||||
for i := 0; i < 2; i++ {
|
||||
accountId := xid.New().String()
|
||||
|
||||
_, err := am.Store.GetAccount(accountId)
|
||||
statusErr, _ := status.FromError(err)
|
||||
if err == nil {
|
||||
log.Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
} else if statusErr.Code() == codes.NotFound {
|
||||
return newAccountWithId(accountId, userID, domain), nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Internal, "error while creating new account")
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
userData, err := am.idpManager.GetAllAccounts()
|
||||
if err != nil {
|
||||
@ -368,7 +383,7 @@ func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo {
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) loadFromCache(ctx context.Context, accountID interface{}) (interface{}, error) {
|
||||
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
|
||||
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
|
||||
}
|
||||
|
||||
@ -458,8 +473,17 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
account.IsDomainPrimaryAccount = primaryDomain
|
||||
account.Domain = strings.ToLower(claims.Domain)
|
||||
account.DomainCategory = claims.DomainCategory
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
userObj := account.Users[claims.UserId]
|
||||
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
||||
account.Domain = lowerDomain
|
||||
}
|
||||
// prevent updating category for different domain until admin logs in
|
||||
if account.Domain == lowerDomain {
|
||||
account.DomainCategory = claims.DomainCategory
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed saving updated account")
|
||||
@ -523,7 +547,10 @@ func (am *DefaultAccountManager) handleNewUserAccount(
|
||||
return nil, status.Errorf(codes.Internal, "failed saving updated account")
|
||||
}
|
||||
} else {
|
||||
account = NewAccount(claims.UserId, lowerDomain)
|
||||
account, err = am.newAccount(claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.updateAccountDomainAttributes(account, claims, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -96,7 +96,8 @@ func TestNewAccount(t *testing.T) {
|
||||
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
account := NewAccount(userId, domain)
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(accountID, userId, domain)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
|
@ -33,7 +33,7 @@ func TestNewStore(t *testing.T) {
|
||||
func TestSaveAccount(t *testing.T) {
|
||||
store := newStore(t)
|
||||
|
||||
account := NewAccount("testuser", "")
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &Peer{
|
||||
@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) {
|
||||
func TestStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
|
||||
account := NewAccount("testuser", "")
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account.Peers["testpeer"] = &Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
|
@ -59,7 +59,10 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
||||
account, err := am.Store.GetUserAccount(userId)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
account = NewAccount(userId, lowerDomain)
|
||||
account, err = am.newAccount(userId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
||||
@ -70,7 +73,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
||||
}
|
||||
}
|
||||
|
||||
if account.Domain != lowerDomain {
|
||||
userObj := account.Users[userId]
|
||||
|
||||
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
||||
account.Domain = lowerDomain
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user