mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-27 05:01:43 +02:00
[management] support account retrieval and creation by private domain (#3825)
* [management] sys initiator save user (#3911) * [management] activity events with multiple external account users (#3914)
This commit is contained in:
parent
0cd36baf67
commit
87148c503f
@ -1730,23 +1730,26 @@ func (am *DefaultAccountManager) GetStore() store.Store {
|
|||||||
return am.Store
|
return am.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates account by private domain.
|
func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
|
||||||
// Expects domain value to be a valid and a private dns domain.
|
|
||||||
func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
|
|
||||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
domain = strings.ToLower(domain)
|
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain)
|
// a primary account already exists for this private domain
|
||||||
|
if err == nil {
|
||||||
|
existingAccount, err := am.Store.GetAccount(ctx, existingPrimaryAccountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if count > 0 {
|
return existingAccount, false, nil
|
||||||
return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create a new account for this private domain
|
||||||
// retry twice for new ID clashes
|
// retry twice for new ID clashes
|
||||||
for range 2 {
|
for range 2 {
|
||||||
accountId := xid.New().String()
|
accountId := xid.New().String()
|
||||||
@ -1776,7 +1779,7 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
|
|||||||
Users: users,
|
Users: users,
|
||||||
// @todo check if using the MSP owner id here is ok
|
// @todo check if using the MSP owner id here is ok
|
||||||
CreatedBy: initiatorId,
|
CreatedBy: initiatorId,
|
||||||
Domain: domain,
|
Domain: strings.ToLower(domain),
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
IsDomainPrimaryAccount: false,
|
IsDomainPrimaryAccount: false,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
@ -1795,19 +1798,22 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := newAccount.AddAllGroup(); err != nil {
|
if err := newAccount.AddAllGroup(); err != nil {
|
||||||
return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
|
if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err)
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
return nil, err
|
"accountId": newAccount.Id,
|
||||||
|
"domain": domain,
|
||||||
|
}).Errorf("failed to create new account: %v", err)
|
||||||
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
|
am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||||
return newAccount, nil
|
return newAccount, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(status.Internal, "failed to create new account by private domain")
|
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||||
@ -1820,21 +1826,29 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// additional check to ensure there is only one account for this domain at the time of update
|
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
|
||||||
count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain)
|
|
||||||
if err != nil {
|
// error is not a not found error
|
||||||
|
if handleNotFound(err) != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if count > 1 {
|
// a primary account already exists for this private domain
|
||||||
return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain")
|
if err == nil {
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"accountId": accountId,
|
||||||
|
"existingAccountId": existingPrimaryAccountID,
|
||||||
|
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
|
||||||
|
return nil, status.Errorf(status.Internal, "cannot update account to primary")
|
||||||
}
|
}
|
||||||
|
|
||||||
account.IsDomainPrimaryAccount = true
|
account.IsDomainPrimaryAccount = true
|
||||||
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err)
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id)
|
"accountId": accountId,
|
||||||
|
}).Errorf("failed to update account to primary: %v", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to update account to primary")
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
|
@ -113,7 +113,7 @@ type Manager interface {
|
|||||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||||
GetStore() store.Store
|
GetStore() store.Store
|
||||||
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
||||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
|
@ -14,7 +14,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -25,6 +24,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/cache"
|
"github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
@ -3198,7 +3198,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -3209,9 +3209,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
|||||||
initiatorId := "test-user"
|
initiatorId := "test-user"
|
||||||
domain := "example.com"
|
domain := "example.com"
|
||||||
|
|
||||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, created)
|
||||||
assert.False(t, account.IsDomainPrimaryAccount)
|
assert.False(t, account.IsDomainPrimaryAccount)
|
||||||
assert.Equal(t, domain, account.Domain)
|
assert.Equal(t, domain, account.Domain)
|
||||||
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
|
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
|
||||||
@ -3220,9 +3221,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
|||||||
assert.Equal(t, 0, len(account.Users))
|
assert.Equal(t, 0, len(account.Users))
|
||||||
assert.Equal(t, 0, len(account.SetupKeys))
|
assert.Equal(t, 0, len(account.SetupKeys))
|
||||||
|
|
||||||
// retry should fail
|
// should return a new account because the previous one is not primary
|
||||||
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||||
assert.Error(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, created2)
|
||||||
|
assert.False(t, account2.IsDomainPrimaryAccount)
|
||||||
|
assert.Equal(t, domain, account2.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, account2.DomainCategory)
|
||||||
|
assert.Equal(t, initiatorId, account2.CreatedBy)
|
||||||
|
assert.Equal(t, 1, len(account2.Groups))
|
||||||
|
assert.Equal(t, 0, len(account2.Users))
|
||||||
|
assert.Equal(t, 0, len(account2.SetupKeys))
|
||||||
|
|
||||||
|
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, account.IsDomainPrimaryAccount)
|
||||||
|
|
||||||
|
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
||||||
|
assert.Error(t, err, "should not be able to update a second account to primary")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||||
@ -3236,14 +3253,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
|||||||
initiatorId := "test-user"
|
initiatorId := "test-user"
|
||||||
domain := "example.com"
|
domain := "example.com"
|
||||||
|
|
||||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, created)
|
||||||
assert.False(t, account.IsDomainPrimaryAccount)
|
assert.False(t, account.IsDomainPrimaryAccount)
|
||||||
|
assert.Equal(t, domain, account.Domain)
|
||||||
|
|
||||||
// retry should fail
|
|
||||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, account.IsDomainPrimaryAccount)
|
assert.True(t, account.IsDomainPrimaryAccount)
|
||||||
|
|
||||||
|
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, created2)
|
||||||
|
assert.True(t, account.IsDomainPrimaryAccount)
|
||||||
|
assert.Equal(t, account.Id, account2.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||||
|
@ -143,11 +143,10 @@ func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events [
|
|||||||
return eventUserInfos, nil
|
return eventUserInfos, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, userId)
|
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, userId string) (map[string]eventUserInfo, error) {
|
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo) (map[string]eventUserInfo, error) {
|
||||||
externalAccountId := ""
|
|
||||||
fetched := make(map[string]struct{})
|
fetched := make(map[string]struct{})
|
||||||
externalUsers := []*types.User{}
|
externalUsers := []*types.User{}
|
||||||
for _, id := range externalUserIds {
|
for _, id := range externalUserIds {
|
||||||
@ -161,25 +160,20 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if externalAccountId != "" && externalAccountId != externalUser.AccountID {
|
|
||||||
return nil, fmt.Errorf("multiple external user accounts in events")
|
|
||||||
}
|
|
||||||
|
|
||||||
if externalAccountId == "" {
|
|
||||||
externalAccountId = externalUser.AccountID
|
|
||||||
}
|
|
||||||
|
|
||||||
fetched[id] = struct{}{}
|
fetched[id] = struct{}{}
|
||||||
externalUsers = append(externalUsers, externalUser)
|
externalUsers = append(externalUsers, externalUser)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we couldn't determine an account, return what we have
|
usersByExternalAccount := map[string][]*types.User{}
|
||||||
if externalAccountId == "" {
|
for _, u := range externalUsers {
|
||||||
log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds)
|
if _, ok := usersByExternalAccount[u.AccountID]; !ok {
|
||||||
return eventUserInfos, nil
|
usersByExternalAccount[u.AccountID] = make([]*types.User, 0)
|
||||||
|
}
|
||||||
|
usersByExternalAccount[u.AccountID] = append(usersByExternalAccount[u.AccountID], u)
|
||||||
}
|
}
|
||||||
|
|
||||||
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, userId, externalUsers)
|
for externalAccountId, externalUsers := range usersByExternalAccount {
|
||||||
|
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, "", externalUsers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -191,6 +185,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
|
|||||||
accountId: externalAccountId,
|
accountId: externalAccountId,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return eventUserInfos, nil
|
return eventUserInfos, nil
|
||||||
}
|
}
|
||||||
|
@ -113,11 +113,12 @@ type MockAccountManager struct {
|
|||||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||||
GetStoreFunc func() store.Store
|
GetStoreFunc func() store.Store
|
||||||
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
|
||||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||||
|
|
||||||
|
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||||
@ -862,11 +863,11 @@ func (am *MockAccountManager) GetStore() store.Store {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
|
func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
|
||||||
if am.CreateAccountByPrivateDomainFunc != nil {
|
if am.GetOrCreateAccountByPrivateDomainFunc != nil {
|
||||||
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
|
return am.GetOrCreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
|
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||||
|
@ -531,10 +531,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
groupsMap[group.ID] = group
|
groupsMap[group.ID] = group
|
||||||
}
|
}
|
||||||
|
|
||||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
var initiatorUser *types.User
|
||||||
|
if initiatorUserID != activity.SystemInitiator {
|
||||||
|
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
initiatorUser = result
|
||||||
|
}
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
for _, update := range updates {
|
for _, update := range updates {
|
||||||
@ -543,7 +547,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||||
ctx, transaction, groupsMap, accountID, initiatorUser, update, addIfNotExists, settings,
|
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to process user update: %w", err)
|
return fmt.Errorf("failed to process user update: %w", err)
|
||||||
@ -629,7 +633,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group,
|
func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group,
|
||||||
accountID string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
accountID, initiatorUserId string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
||||||
|
|
||||||
if update == nil {
|
if update == nil {
|
||||||
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||||
@ -653,10 +657,12 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
|||||||
updatedUser.Issued = update.Issued
|
updatedUser.Issued = update.Issued
|
||||||
updatedUser.IntegrationReference = update.IntegrationReference
|
updatedUser.IntegrationReference = update.IntegrationReference
|
||||||
|
|
||||||
transferredOwnerRole, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update)
|
var transferredOwnerRole bool
|
||||||
|
result, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, nil, nil, err
|
return false, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
transferredOwnerRole = result
|
||||||
|
|
||||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id)
|
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -682,7 +688,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
|||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers := len(userPeers) > 0
|
updateAccountPeers := len(userPeers) > 0
|
||||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
|
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole)
|
||||||
|
|
||||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||||
}
|
}
|
||||||
@ -709,7 +715,7 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, ac
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
|
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
|
||||||
if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
|
if initiatorUser != nil && initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
|
||||||
newInitiatorUser := initiatorUser.Copy()
|
newInitiatorUser := initiatorUser.Copy()
|
||||||
newInitiatorUser.Role = types.UserRoleAdmin
|
newInitiatorUser.Role = types.UserRoleAdmin
|
||||||
|
|
||||||
@ -737,6 +743,10 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.Us
|
|||||||
|
|
||||||
// validateUserUpdate validates the update operation for a user.
|
// validateUserUpdate validates the update operation for a user.
|
||||||
func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error {
|
func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error {
|
||||||
|
if initiatorUser == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// @todo double check these
|
// @todo double check these
|
||||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||||
@ -818,10 +828,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
|||||||
return nil, status.NewPermissionValidationError(err)
|
return nil, status.NewPermissionValidationError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
var user *types.User
|
||||||
|
if initiatorUserID != activity.SystemInitiator {
|
||||||
|
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||||
}
|
}
|
||||||
|
user = result
|
||||||
|
}
|
||||||
|
|
||||||
accountUsers := []*types.User{}
|
accountUsers := []*types.User{}
|
||||||
switch {
|
switch {
|
||||||
@ -830,7 +844,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case user.AccountID == accountID:
|
case user != nil && user.AccountID == accountID:
|
||||||
accountUsers = append(accountUsers, user)
|
accountUsers = append(accountUsers, user)
|
||||||
default:
|
default:
|
||||||
return map[string]*types.UserInfo{}, nil
|
return map[string]*types.UserInfo{}, nil
|
||||||
|
Loading…
x
Reference in New Issue
Block a user