[management] support account retrieval and creation by private domain (#3825)

* [management] sys initiator save user (#3911)

* [management] activity events with multiple external account users (#3914)
This commit is contained in:
Pedro Maia Costa 2025-06-04 11:21:31 +01:00 committed by GitHub
parent 0cd36baf67
commit 87148c503f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 121 additions and 73 deletions

View File

@ -1730,23 +1730,26 @@ func (am *DefaultAccountManager) GetStore() store.Store {
return am.Store
}
// Creates account by private domain.
// Expects domain value to be a valid and a private dns domain.
func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel()
domain = strings.ToLower(domain)
count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain)
if err != nil {
return nil, err
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
if handleNotFound(err) != nil {
return nil, false, err
}
if count > 0 {
return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists")
// a primary account already exists for this private domain
if err == nil {
existingAccount, err := am.Store.GetAccount(ctx, existingPrimaryAccountID)
if err != nil {
return nil, false, err
}
return existingAccount, false, nil
}
// create a new account for this private domain
// retry twice for new ID clashes
for range 2 {
accountId := xid.New().String()
@ -1776,7 +1779,7 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
Users: users,
// @todo check if using the MSP owner id here is ok
CreatedBy: initiatorId,
Domain: domain,
Domain: strings.ToLower(domain),
DomainCategory: types.PrivateCategory,
IsDomainPrimaryAccount: false,
Routes: routes,
@ -1795,19 +1798,22 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
}
if err := newAccount.AddAllGroup(); err != nil {
return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
}
if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err)
return nil, err
log.WithContext(ctx).WithFields(log.Fields{
"accountId": newAccount.Id,
"domain": domain,
}).Errorf("failed to create new account: %v", err)
return nil, false, err
}
am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
return newAccount, true, nil
}
return nil, status.Errorf(status.Internal, "failed to create new account by private domain")
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
}
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
@ -1820,21 +1826,29 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return account, nil
}
// additional check to ensure there is only one account for this domain at the time of update
count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain)
if err != nil {
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return nil, err
}
if count > 1 {
return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain")
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return nil, status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err)
return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id)
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return nil, status.Errorf(status.Internal, "failed to update account to primary")
}
return account, nil

View File

@ -113,7 +113,7 @@ type Manager interface {
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)

View File

@ -14,7 +14,6 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -25,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@ -3198,7 +3198,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
}
}
func Test_CreateAccountByPrivateDomain(t *testing.T) {
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@ -3209,9 +3209,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.True(t, created)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
@ -3220,9 +3221,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
assert.Equal(t, 0, len(account.Users))
assert.Equal(t, 0, len(account.SetupKeys))
// retry should fail
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.Error(t, err)
// should return a new account because the previous one is not primary
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.True(t, created2)
assert.False(t, account2.IsDomainPrimaryAccount)
assert.Equal(t, domain, account2.Domain)
assert.Equal(t, types.PrivateCategory, account2.DomainCategory)
assert.Equal(t, initiatorId, account2.CreatedBy)
assert.Equal(t, 1, len(account2.Groups))
assert.Equal(t, 0, len(account2.Users))
assert.Equal(t, 0, len(account2.SetupKeys))
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
assert.Error(t, err, "should not be able to update a second account to primary")
}
func Test_UpdateToPrimaryAccount(t *testing.T) {
@ -3236,14 +3253,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.True(t, created)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
// retry should fail
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.False(t, created2)
assert.True(t, account.IsDomainPrimaryAccount)
assert.Equal(t, account.Id, account2.Id)
}
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {

View File

@ -143,11 +143,10 @@ func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events [
return eventUserInfos, nil
}
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, userId)
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos)
}
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, userId string) (map[string]eventUserInfo, error) {
externalAccountId := ""
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo) (map[string]eventUserInfo, error) {
fetched := make(map[string]struct{})
externalUsers := []*types.User{}
for _, id := range externalUserIds {
@ -161,34 +160,30 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue
}
if externalAccountId != "" && externalAccountId != externalUser.AccountID {
return nil, fmt.Errorf("multiple external user accounts in events")
}
if externalAccountId == "" {
externalAccountId = externalUser.AccountID
}
fetched[id] = struct{}{}
externalUsers = append(externalUsers, externalUser)
}
// if we couldn't determine an account, return what we have
if externalAccountId == "" {
log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds)
return eventUserInfos, nil
usersByExternalAccount := map[string][]*types.User{}
for _, u := range externalUsers {
if _, ok := usersByExternalAccount[u.AccountID]; !ok {
usersByExternalAccount[u.AccountID] = make([]*types.User, 0)
}
usersByExternalAccount[u.AccountID] = append(usersByExternalAccount[u.AccountID], u)
}
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, userId, externalUsers)
if err != nil {
return nil, err
}
for externalAccountId, externalUsers := range usersByExternalAccount {
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, "", externalUsers)
if err != nil {
return nil, err
}
for i, k := range externalUserInfos {
eventUserInfos[i] = eventUserInfo{
email: k.Email,
name: k.Name,
accountId: externalAccountId,
for i, k := range externalUserInfos {
eventUserInfos[i] = eventUserInfo{
email: k.Email,
name: k.Name,
accountId: externalAccountId,
}
}
}

View File

@ -113,11 +113,12 @@ type MockAccountManager struct {
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
GetStoreFunc func() store.Store
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
@ -862,11 +863,11 @@ func (am *MockAccountManager) GetStore() store.Store {
return nil
}
func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
if am.CreateAccountByPrivateDomainFunc != nil {
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
if am.GetOrCreateAccountByPrivateDomainFunc != nil {
return am.GetOrCreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
}
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {

View File

@ -531,9 +531,13 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
groupsMap[group.ID] = group
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
var initiatorUser *types.User
if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
}
initiatorUser = result
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@ -543,7 +547,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, accountID, initiatorUser, update, addIfNotExists, settings,
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
)
if err != nil {
return fmt.Errorf("failed to process user update: %w", err)
@ -629,7 +633,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
}
func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group,
accountID string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
accountID, initiatorUserId string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
if update == nil {
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
@ -653,10 +657,12 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
updatedUser.Issued = update.Issued
updatedUser.IntegrationReference = update.IntegrationReference
transferredOwnerRole, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update)
var transferredOwnerRole bool
result, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update)
if err != nil {
return false, nil, nil, nil, err
}
transferredOwnerRole = result
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id)
if err != nil {
@ -682,7 +688,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
}
updateAccountPeers := len(userPeers) > 0
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
}
@ -709,7 +715,7 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, ac
}
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
if initiatorUser != nil && initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = types.UserRoleAdmin
@ -737,6 +743,10 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.Us
// validateUserUpdate validates the update operation for a user.
func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error {
if initiatorUser == nil {
return nil
}
// @todo double check these
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
@ -818,9 +828,13 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
return nil, status.NewPermissionValidationError(err)
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
var user *types.User
if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
user = result
}
accountUsers := []*types.User{}
@ -830,7 +844,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
if err != nil {
return nil, err
}
case user.AccountID == accountID:
case user != nil && user.AccountID == accountID:
accountUsers = append(accountUsers, user)
default:
return map[string]*types.UserInfo{}, nil