mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-27 05:01:43 +02:00
[management] support account retrieval and creation by private domain (#3825)
* [management] sys initiator save user (#3911) * [management] activity events with multiple external account users (#3914)
This commit is contained in:
parent
0cd36baf67
commit
87148c503f
@ -1730,23 +1730,26 @@ func (am *DefaultAccountManager) GetStore() store.Store {
|
||||
return am.Store
|
||||
}
|
||||
|
||||
// Creates account by private domain.
|
||||
// Expects domain value to be a valid and a private dns domain.
|
||||
func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
|
||||
func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
|
||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||
defer cancel()
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain)
|
||||
// a primary account already exists for this private domain
|
||||
if err == nil {
|
||||
existingAccount, err := am.Store.GetAccount(ctx, existingPrimaryAccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists")
|
||||
return existingAccount, false, nil
|
||||
}
|
||||
|
||||
// create a new account for this private domain
|
||||
// retry twice for new ID clashes
|
||||
for range 2 {
|
||||
accountId := xid.New().String()
|
||||
@ -1776,7 +1779,7 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
|
||||
Users: users,
|
||||
// @todo check if using the MSP owner id here is ok
|
||||
CreatedBy: initiatorId,
|
||||
Domain: domain,
|
||||
Domain: strings.ToLower(domain),
|
||||
DomainCategory: types.PrivateCategory,
|
||||
IsDomainPrimaryAccount: false,
|
||||
Routes: routes,
|
||||
@ -1795,19 +1798,22 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
|
||||
}
|
||||
|
||||
if err := newAccount.AddAllGroup(); err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||
}
|
||||
|
||||
if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err)
|
||||
return nil, err
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": newAccount.Id,
|
||||
"domain": domain,
|
||||
}).Errorf("failed to create new account: %v", err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
return newAccount, true, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.Internal, "failed to create new account by private domain")
|
||||
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
@ -1820,21 +1826,29 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// additional check to ensure there is only one account for this domain at the time of update
|
||||
count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain)
|
||||
if err != nil {
|
||||
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
|
||||
|
||||
// error is not a not found error
|
||||
if handleNotFound(err) != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain")
|
||||
// a primary account already exists for this private domain
|
||||
if err == nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": accountId,
|
||||
"existingAccountId": existingPrimaryAccountID,
|
||||
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
|
||||
return nil, status.Errorf(status.Internal, "cannot update account to primary")
|
||||
}
|
||||
|
||||
account.IsDomainPrimaryAccount = true
|
||||
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err)
|
||||
return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id)
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": accountId,
|
||||
}).Errorf("failed to update account to primary: %v", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to update account to primary")
|
||||
}
|
||||
|
||||
return account, nil
|
||||
|
@ -113,7 +113,7 @@ type Manager interface {
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
GetStore() store.Store
|
||||
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -25,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@ -3198,7 +3198,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -3209,9 +3209,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, created)
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
|
||||
@ -3220,9 +3221,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
assert.Equal(t, 0, len(account.Users))
|
||||
assert.Equal(t, 0, len(account.SetupKeys))
|
||||
|
||||
// retry should fail
|
||||
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.Error(t, err)
|
||||
// should return a new account because the previous one is not primary
|
||||
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, created2)
|
||||
assert.False(t, account2.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account2.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, account2.DomainCategory)
|
||||
assert.Equal(t, initiatorId, account2.CreatedBy)
|
||||
assert.Equal(t, 1, len(account2.Groups))
|
||||
assert.Equal(t, 0, len(account2.Users))
|
||||
assert.Equal(t, 0, len(account2.SetupKeys))
|
||||
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
||||
assert.Error(t, err, "should not be able to update a second account to primary")
|
||||
}
|
||||
|
||||
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
@ -3236,14 +3253,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, created)
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
|
||||
// retry should fail
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, created2)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, account.Id, account2.Id)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
|
@ -143,11 +143,10 @@ func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events [
|
||||
return eventUserInfos, nil
|
||||
}
|
||||
|
||||
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, userId)
|
||||
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, userId string) (map[string]eventUserInfo, error) {
|
||||
externalAccountId := ""
|
||||
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo) (map[string]eventUserInfo, error) {
|
||||
fetched := make(map[string]struct{})
|
||||
externalUsers := []*types.User{}
|
||||
for _, id := range externalUserIds {
|
||||
@ -161,25 +160,20 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
|
||||
continue
|
||||
}
|
||||
|
||||
if externalAccountId != "" && externalAccountId != externalUser.AccountID {
|
||||
return nil, fmt.Errorf("multiple external user accounts in events")
|
||||
}
|
||||
|
||||
if externalAccountId == "" {
|
||||
externalAccountId = externalUser.AccountID
|
||||
}
|
||||
|
||||
fetched[id] = struct{}{}
|
||||
externalUsers = append(externalUsers, externalUser)
|
||||
}
|
||||
|
||||
// if we couldn't determine an account, return what we have
|
||||
if externalAccountId == "" {
|
||||
log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds)
|
||||
return eventUserInfos, nil
|
||||
usersByExternalAccount := map[string][]*types.User{}
|
||||
for _, u := range externalUsers {
|
||||
if _, ok := usersByExternalAccount[u.AccountID]; !ok {
|
||||
usersByExternalAccount[u.AccountID] = make([]*types.User, 0)
|
||||
}
|
||||
usersByExternalAccount[u.AccountID] = append(usersByExternalAccount[u.AccountID], u)
|
||||
}
|
||||
|
||||
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, userId, externalUsers)
|
||||
for externalAccountId, externalUsers := range usersByExternalAccount {
|
||||
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, "", externalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -191,6 +185,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
|
||||
accountId: externalAccountId,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return eventUserInfos, nil
|
||||
}
|
||||
|
@ -113,11 +113,12 @@ type MockAccountManager struct {
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
GetStoreFunc func() store.Store
|
||||
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
@ -862,11 +863,11 @@ func (am *MockAccountManager) GetStore() store.Store {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
|
||||
if am.CreateAccountByPrivateDomainFunc != nil {
|
||||
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
|
||||
func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
|
||||
if am.GetOrCreateAccountByPrivateDomainFunc != nil {
|
||||
return am.GetOrCreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
|
||||
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
|
@ -531,10 +531,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
var initiatorUser *types.User
|
||||
if initiatorUserID != activity.SystemInitiator {
|
||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initiatorUser = result
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, update := range updates {
|
||||
@ -543,7 +547,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||
ctx, transaction, groupsMap, accountID, initiatorUser, update, addIfNotExists, settings,
|
||||
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process user update: %w", err)
|
||||
@ -629,7 +633,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group,
|
||||
accountID string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
||||
accountID, initiatorUserId string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
||||
|
||||
if update == nil {
|
||||
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
@ -653,10 +657,12 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
updatedUser.Issued = update.Issued
|
||||
updatedUser.IntegrationReference = update.IntegrationReference
|
||||
|
||||
transferredOwnerRole, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update)
|
||||
var transferredOwnerRole bool
|
||||
result, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, err
|
||||
}
|
||||
transferredOwnerRole = result
|
||||
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id)
|
||||
if err != nil {
|
||||
@ -682,7 +688,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
}
|
||||
|
||||
updateAccountPeers := len(userPeers) > 0
|
||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
|
||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole)
|
||||
|
||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||
}
|
||||
@ -709,7 +715,7 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, ac
|
||||
}
|
||||
|
||||
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
|
||||
if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
|
||||
if initiatorUser != nil && initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
|
||||
newInitiatorUser := initiatorUser.Copy()
|
||||
newInitiatorUser.Role = types.UserRoleAdmin
|
||||
|
||||
@ -737,6 +743,10 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.Us
|
||||
|
||||
// validateUserUpdate validates the update operation for a user.
|
||||
func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error {
|
||||
if initiatorUser == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// @todo double check these
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
@ -818,10 +828,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
var user *types.User
|
||||
if initiatorUserID != activity.SystemInitiator {
|
||||
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
user = result
|
||||
}
|
||||
|
||||
accountUsers := []*types.User{}
|
||||
switch {
|
||||
@ -830,7 +844,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case user.AccountID == accountID:
|
||||
case user != nil && user.AccountID == accountID:
|
||||
accountUsers = append(accountUsers, user)
|
||||
default:
|
||||
return map[string]*types.UserInfo{}, nil
|
||||
|
Loading…
x
Reference in New Issue
Block a user