From 87148c503f8c236c39492d05bba7681473811d43 Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Wed, 4 Jun 2025 11:21:31 +0100 Subject: [PATCH] [management] support account retrieval and creation by private domain (#3825) * [management] sys initiator save user (#3911) * [management] activity events with multiple external account users (#3914) --- management/server/account.go | 60 ++++++++++++------- management/server/account/manager.go | 2 +- management/server/account_test.go | 40 ++++++++++--- management/server/event.go | 43 ++++++------- management/server/mock_server/account_mock.go | 11 ++-- management/server/user.go | 38 ++++++++---- 6 files changed, 121 insertions(+), 73 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 033ec5fa1..63879802a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1730,23 +1730,26 @@ func (am *DefaultAccountManager) GetStore() store.Store { return am.Store } -// Creates account by private domain. -// Expects domain value to be a valid and a private dns domain. -func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) { cancel := am.Store.AcquireGlobalLock(ctx) defer cancel() - domain = strings.ToLower(domain) - - count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain) - if err != nil { - return nil, err + existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) + if handleNotFound(err) != nil { + return nil, false, err } - if count > 0 { - return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists") + // a primary account already exists for this private domain + if err == nil { + existingAccount, err := am.Store.GetAccount(ctx, existingPrimaryAccountID) + if err != nil { + return nil, false, err + } + + return existingAccount, false, nil } + // create a new account for this private domain // retry twice for new ID clashes for range 2 { accountId := xid.New().String() @@ -1776,7 +1779,7 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex Users: users, // @todo check if using the MSP owner id here is ok CreatedBy: initiatorId, - Domain: domain, + Domain: strings.ToLower(domain), DomainCategory: types.PrivateCategory, IsDomainPrimaryAccount: false, Routes: routes, @@ -1795,19 +1798,22 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex } if err := newAccount.AddAllGroup(); err != nil { - return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain") + return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain") } if err := am.Store.SaveAccount(ctx, newAccount); err != nil { - log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err) - return nil, err + log.WithContext(ctx).WithFields(log.Fields{ + "accountId": newAccount.Id, + "domain": domain, + }).Errorf("failed to create new account: %v", err) + return nil, false, err } am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil) - return newAccount, nil + return newAccount, true, nil } - return nil, status.Errorf(status.Internal, "failed to create new account by private domain") + return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain") } func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { @@ -1820,21 +1826,29 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc return account, nil } - // additional check to ensure there is only one account for this domain at the time of update - count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain) - if err != nil { + existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) + + // error is not a not found error + if handleNotFound(err) != nil { return nil, err } - if count > 1 { - return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain") + // a primary account already exists for this private domain + if err == nil { + log.WithContext(ctx).WithFields(log.Fields{ + "accountId": accountId, + "existingAccountId": existingPrimaryAccountID, + }).Errorf("cannot update account to primary, another account already exists as primary for the same domain") + return nil, status.Errorf(status.Internal, "cannot update account to primary") } account.IsDomainPrimaryAccount = true if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err) - return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id) + log.WithContext(ctx).WithFields(log.Fields{ + "accountId": accountId, + }).Errorf("failed to update account to primary: %v", err) + return nil, status.Errorf(status.Internal, "failed to update account to primary") } return account, nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 9bc4f9605..030bd94ef 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -113,7 +113,7 @@ type Manager interface { BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error GetStore() store.Store - CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) + GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index c5583d226..5ada28ca3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -14,7 +14,6 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/netbirdio/netbird/management/server/idp" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,6 +24,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -3198,7 +3198,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { } } -func Test_CreateAccountByPrivateDomain(t *testing.T) { +func Test_GetCreateAccountByPrivateDomain(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -3209,9 +3209,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) { initiatorId := "test-user" domain := "example.com" - account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain) assert.NoError(t, err) + assert.True(t, created) assert.False(t, account.IsDomainPrimaryAccount) assert.Equal(t, domain, account.Domain) assert.Equal(t, types.PrivateCategory, account.DomainCategory) @@ -3220,9 +3221,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) { assert.Equal(t, 0, len(account.Users)) assert.Equal(t, 0, len(account.SetupKeys)) - // retry should fail - _, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) - assert.Error(t, err) + // should return a new account because the previous one is not primary + account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.NoError(t, err) + + assert.True(t, created2) + assert.False(t, account2.IsDomainPrimaryAccount) + assert.Equal(t, domain, account2.Domain) + assert.Equal(t, types.PrivateCategory, account2.DomainCategory) + assert.Equal(t, initiatorId, account2.CreatedBy) + assert.Equal(t, 1, len(account2.Groups)) + assert.Equal(t, 0, len(account2.Users)) + assert.Equal(t, 0, len(account2.SetupKeys)) + + account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) + assert.NoError(t, err) + assert.True(t, account.IsDomainPrimaryAccount) + + _, err = manager.UpdateToPrimaryAccount(ctx, account2.Id) + assert.Error(t, err, "should not be able to update a second account to primary") } func Test_UpdateToPrimaryAccount(t *testing.T) { @@ -3236,14 +3253,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { initiatorId := "test-user" domain := "example.com" - account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain) assert.NoError(t, err) + assert.True(t, created) assert.False(t, account.IsDomainPrimaryAccount) + assert.Equal(t, domain, account.Domain) - // retry should fail account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) assert.NoError(t, err) assert.True(t, account.IsDomainPrimaryAccount) + + account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.NoError(t, err) + assert.False(t, created2) + assert.True(t, account.IsDomainPrimaryAccount) + assert.Equal(t, account.Id, account2.Id) } func TestDefaultAccountManager_IsCacheCold(t *testing.T) { diff --git a/management/server/event.go b/management/server/event.go index 2952edc8c..d94714e2c 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -143,11 +143,10 @@ func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events [ return eventUserInfos, nil } - return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, userId) + return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos) } -func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, userId string) (map[string]eventUserInfo, error) { - externalAccountId := "" +func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo) (map[string]eventUserInfo, error) { fetched := make(map[string]struct{}) externalUsers := []*types.User{} for _, id := range externalUserIds { @@ -161,34 +160,30 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, continue } - if externalAccountId != "" && externalAccountId != externalUser.AccountID { - return nil, fmt.Errorf("multiple external user accounts in events") - } - - if externalAccountId == "" { - externalAccountId = externalUser.AccountID - } - fetched[id] = struct{}{} externalUsers = append(externalUsers, externalUser) } - // if we couldn't determine an account, return what we have - if externalAccountId == "" { - log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds) - return eventUserInfos, nil + usersByExternalAccount := map[string][]*types.User{} + for _, u := range externalUsers { + if _, ok := usersByExternalAccount[u.AccountID]; !ok { + usersByExternalAccount[u.AccountID] = make([]*types.User, 0) + } + usersByExternalAccount[u.AccountID] = append(usersByExternalAccount[u.AccountID], u) } - externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, userId, externalUsers) - if err != nil { - return nil, err - } + for externalAccountId, externalUsers := range usersByExternalAccount { + externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, "", externalUsers) + if err != nil { + return nil, err + } - for i, k := range externalUserInfos { - eventUserInfos[i] = eventUserInfo{ - email: k.Email, - name: k.Name, - accountId: externalAccountId, + for i, k := range externalUserInfos { + eventUserInfos[i] = eventUserInfo{ + email: k.Email, + name: k.Name, + accountId: externalAccountId, + } } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 0dd3f927e..ed47d3914 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -113,11 +113,12 @@ type MockAccountManager struct { DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) GetStoreFunc func() store.Store - CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error) UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) + + GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { @@ -862,11 +863,11 @@ func (am *MockAccountManager) GetStore() store.Store { return nil } -func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) { - if am.CreateAccountByPrivateDomainFunc != nil { - return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain) +func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) { + if am.GetOrCreateAccountByPrivateDomainFunc != nil { + return am.GetOrCreateAccountByPrivateDomainFunc(ctx, initiatorId, domain) } - return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented") + return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented") } func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { diff --git a/management/server/user.go b/management/server/user.go index 5c162c50b..6d780cda3 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -531,9 +531,13 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, groupsMap[group.ID] = group } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) - if err != nil { - return nil, err + var initiatorUser *types.User + if initiatorUserID != activity.SystemInitiator { + result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + if err != nil { + return nil, err + } + initiatorUser = result } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -543,7 +547,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( - ctx, transaction, groupsMap, accountID, initiatorUser, update, addIfNotExists, settings, + ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process user update: %w", err) @@ -629,7 +633,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac } func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group, - accountID string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { + accountID, initiatorUserId string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { if update == nil { return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") @@ -653,10 +657,12 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact updatedUser.Issued = update.Issued updatedUser.IntegrationReference = update.IntegrationReference - transferredOwnerRole, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update) + var transferredOwnerRole bool + result, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update) if err != nil { return false, nil, nil, nil, err } + transferredOwnerRole = result userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id) if err != nil { @@ -682,7 +688,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } @@ -709,7 +715,7 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, ac } func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) { - if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner { + if initiatorUser != nil && initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner { newInitiatorUser := initiatorUser.Copy() newInitiatorUser.Role = types.UserRoleAdmin @@ -737,6 +743,10 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.Us // validateUserUpdate validates the update operation for a user. func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error { + if initiatorUser == nil { + return nil + } + // @todo double check these if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") @@ -818,9 +828,13 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun return nil, status.NewPermissionValidationError(err) } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) - if err != nil { - return nil, fmt.Errorf("failed to get user: %w", err) + var user *types.User + if initiatorUserID != activity.SystemInitiator { + result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + user = result } accountUsers := []*types.User{} @@ -830,7 +844,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun if err != nil { return nil, err } - case user.AccountID == accountID: + case user != nil && user.AccountID == accountID: accountUsers = append(accountUsers, user) default: return map[string]*types.UserInfo{}, nil