mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 16:54:16 +01:00
Check if new account ID is already being used (#364)
This commit is contained in:
parent
35c7cae267
commit
f9f2d7c7ef
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
|||||||
.idea
|
.idea
|
||||||
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
bin/
|
bin/
|
||||||
|
@ -100,12 +100,6 @@ type UserInfo struct {
|
|||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccount creates a new Account with a generated ID and generated default setup keys
|
|
||||||
func NewAccount(userId, domain string) *Account {
|
|
||||||
accountId := xid.New().String()
|
|
||||||
return newAccountWithId(accountId, userId, domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) Copy() *Account {
|
func (a *Account) Copy() *Account {
|
||||||
peers := map[string]*Peer{}
|
peers := map[string]*Peer{}
|
||||||
for id, peer := range a.Peers {
|
for id, peer := range a.Peers {
|
||||||
@ -198,6 +192,27 @@ func BuildManager(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||||
|
// If ID is already in use (due to collision) we try one more time before returning error
|
||||||
|
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
accountId := xid.New().String()
|
||||||
|
|
||||||
|
_, err := am.Store.GetAccount(accountId)
|
||||||
|
statusErr, _ := status.FromError(err)
|
||||||
|
if err == nil {
|
||||||
|
log.Warnf("an account with ID already exists, retrying...")
|
||||||
|
continue
|
||||||
|
} else if statusErr.Code() == codes.NotFound {
|
||||||
|
return newAccountWithId(accountId, userID, domain), nil
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(codes.Internal, "error while creating new account")
|
||||||
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) warmupIDPCache() error {
|
func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||||
userData, err := am.idpManager.GetAllAccounts()
|
userData, err := am.idpManager.GetAllAccounts()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -368,7 +383,7 @@ func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) loadFromCache(ctx context.Context, accountID interface{}) (interface{}, error) {
|
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
|
||||||
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
|
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -458,8 +473,17 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
|||||||
primaryDomain bool,
|
primaryDomain bool,
|
||||||
) error {
|
) error {
|
||||||
account.IsDomainPrimaryAccount = primaryDomain
|
account.IsDomainPrimaryAccount = primaryDomain
|
||||||
account.Domain = strings.ToLower(claims.Domain)
|
|
||||||
|
lowerDomain := strings.ToLower(claims.Domain)
|
||||||
|
userObj := account.Users[claims.UserId]
|
||||||
|
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
||||||
|
account.Domain = lowerDomain
|
||||||
|
}
|
||||||
|
// prevent updating category for different domain until admin logs in
|
||||||
|
if account.Domain == lowerDomain {
|
||||||
account.DomainCategory = claims.DomainCategory
|
account.DomainCategory = claims.DomainCategory
|
||||||
|
}
|
||||||
|
|
||||||
err := am.Store.SaveAccount(account)
|
err := am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "failed saving updated account")
|
return status.Errorf(codes.Internal, "failed saving updated account")
|
||||||
@ -523,7 +547,10 @@ func (am *DefaultAccountManager) handleNewUserAccount(
|
|||||||
return nil, status.Errorf(codes.Internal, "failed saving updated account")
|
return nil, status.Errorf(codes.Internal, "failed saving updated account")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
account = NewAccount(claims.UserId, lowerDomain)
|
account, err = am.newAccount(claims.UserId, lowerDomain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
err = am.updateAccountDomainAttributes(account, claims, true)
|
err = am.updateAccountDomainAttributes(account, claims, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -96,7 +96,8 @@ func TestNewAccount(t *testing.T) {
|
|||||||
|
|
||||||
domain := "netbird.io"
|
domain := "netbird.io"
|
||||||
userId := "account_creator"
|
userId := "account_creator"
|
||||||
account := NewAccount(userId, domain)
|
accountID := "account_id"
|
||||||
|
account := newAccountWithId(accountID, userId, domain)
|
||||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ func TestNewStore(t *testing.T) {
|
|||||||
func TestSaveAccount(t *testing.T) {
|
func TestSaveAccount(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
|
|
||||||
account := NewAccount("testuser", "")
|
account := newAccountWithId("account_id", "testuser", "")
|
||||||
setupKey := GenerateDefaultSetupKey()
|
setupKey := GenerateDefaultSetupKey()
|
||||||
account.SetupKeys[setupKey.Key] = setupKey
|
account.SetupKeys[setupKey.Key] = setupKey
|
||||||
account.Peers["testpeer"] = &Peer{
|
account.Peers["testpeer"] = &Peer{
|
||||||
@ -72,7 +72,7 @@ func TestSaveAccount(t *testing.T) {
|
|||||||
func TestStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
|
|
||||||
account := NewAccount("testuser", "")
|
account := newAccountWithId("account_id", "testuser", "")
|
||||||
account.Peers["testpeer"] = &Peer{
|
account.Peers["testpeer"] = &Peer{
|
||||||
Key: "peerkey",
|
Key: "peerkey",
|
||||||
SetupKey: "peerkeysetupkey",
|
SetupKey: "peerkeysetupkey",
|
||||||
|
@ -59,7 +59,10 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
|||||||
account, err := am.Store.GetUserAccount(userId)
|
account, err := am.Store.GetUserAccount(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
account = NewAccount(userId, lowerDomain)
|
account, err = am.newAccount(userId, lowerDomain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
return nil, status.Errorf(codes.Internal, "failed creating account")
|
||||||
@ -70,7 +73,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Domain != lowerDomain {
|
userObj := account.Users[userId]
|
||||||
|
|
||||||
|
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
||||||
account.Domain = lowerDomain
|
account.Domain = lowerDomain
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user