diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index bcec2472f..31bff26cb 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -91,13 +92,13 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) - + permissionsManagerMock := permissions.NewManagerMock() ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 72e7c6d1c..352abd62b 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -49,6 +49,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -1438,6 +1439,8 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + permissionsManagerMock := permissions.NewManagerMock() + ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) @@ -1446,7 +1449,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri Return(&types.Settings{}, nil). AnyTimes() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index a765cceb5..8ee8294cf 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -9,10 +9,11 @@ import ( "github.com/golang/mock/gomock" "github.com/netbirdio/management-integrations/integrations" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" @@ -23,6 +24,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -198,11 +200,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + permissionsManagerMock := permissions.NewManagerMock() ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index db70dfe79..af800282e 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250327214345-49bce94ab4d7 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 6fbf353bd..25891fbf9 100644 --- a/go.sum +++ b/go.sum @@ -490,8 +490,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250327214345-49bce94ab4d7 h1:Quma+ju/eiI6/p6XcHO9rBUtj4gdBPyA6AVIBym6Q0Y= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250327214345-49bce94ab4d7/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203 h1:uxxbLPXQgC9VO15epNPtrD6zazyd5rZeqC5hQSmCdZU= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= diff --git a/management/client/client_test.go b/management/client/client_test.go index 24204688d..6c30ff371 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -74,6 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + permissionsManagerMock := permissions.NewManagerMock() ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) @@ -87,7 +89,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.Settings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index f0b8d5d12..d6735f955 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -52,7 +52,6 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -203,15 +202,14 @@ var ( return fmt.Errorf("failed to initialize integrated peer validator: %v", err) } + permissionsManager := integrations.InitPermissionsManager(store) userManager := users.NewManager(store) extraSettingsManager := integrations.NewManager(eventStore) - settingsManager := settings.NewManager(store, userManager, extraSettingsManager) - permissionsManager := permissions.NewManager(userManager, settingsManager) + settingsManager := settings.NewManager(store, userManager, extraSettingsManager, permissionsManager) peersManager := peers.NewManager(store, permissionsManager) proxyController := integrations.NewController(store) - accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager) + dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 0567a0d78..0b52df2f0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" @@ -89,6 +90,8 @@ type DefaultAccountManager struct { integratedPeerValidator integrated_validator.IntegratedValidator metrics telemetry.AppMetrics + + permissionsManager permissions.Manager } // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. @@ -156,6 +159,7 @@ func BuildManager( metrics telemetry.AppMetrics, proxyController port_forwarding.Controller, settingsManager settings.Manager, + permissionsManager permissions.Manager, ) (*DefaultAccountManager, error) { start := time.Now() defer func() { @@ -180,6 +184,7 @@ func BuildManager( requestBuffer: NewAccountRequestBuffer(ctx, store), proxyController: proxyController, settingsManager: settingsManager, + permissionsManager: permissionsManager, } accountsCounter, err := store.GetAccountsCounter(ctx) if err != nil { @@ -253,13 +258,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } - user, err := account.FindUser(userID) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Write) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to validate user permissions: %w", err) } - if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") + if !allowed { + return nil, status.NewPermissionDeniedError() } err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) @@ -503,16 +508,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return err } - user, err := account.FindUser(userID) + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Accounts, permissions.Write) if err != nil { - return err + return fmt.Errorf("failed to validate user permissions: %w", err) } - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") - } - - if user.Role != types.UserRoleOwner { + if !allowed { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } @@ -542,14 +543,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u } userInfo, ok := userInfosMap[userID] - if !ok { - return status.Errorf(status.NotFound, "user info not found for user %s", userID) - } - - _, err = am.deleteRegularUser(ctx, accountID, userID, userInfo) - if err != nil { - log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) - return err + if ok { + _, err = am.deleteRegularUser(ctx, accountID, userID, userInfo) + if err != nil { + log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) + return err + } } err = am.Store.DeleteAccount(ctx, account) @@ -1027,8 +1026,8 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return nil, err } - if user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } return am.Store.GetAccount(ctx, accountID) @@ -1061,8 +1060,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u return accountID, user.Id, nil } - if user.AccountID != accountID { - return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID) + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return "", "", err } if !user.IsServiceUser && userAuth.Invited { @@ -1521,7 +1520,11 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return nil, err } - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err + } + + if !user.HasAdminPower() && !user.IsServiceUser { return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } @@ -1606,3 +1609,113 @@ func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, ma func (am *DefaultAccountManager) GetStore() store.Store { return am.Store } + +// Creates account by private domain. +// Expects domain value to be a valid and a private dns domain. +func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) { + cancel := am.Store.AcquireGlobalLock(ctx) + defer cancel() + + domain = strings.ToLower(domain) + + count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain) + if err != nil { + return nil, err + } + + if count > 0 { + return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists") + } + + // retry twice for new ID clashes + for range 2 { + accountId := xid.New().String() + + exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) + if err != nil || exists { + continue + } + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[route.ID]*route.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + + newAccount := &types.Account{ + Id: accountId, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + // @todo check if using the MSP owner id here is ok + CreatedBy: initiatorId, + Domain: domain, + DomainCategory: types.PrivateCategory, + IsDomainPrimaryAccount: false, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, + }, + } + + if err := newAccount.AddAllGroup(); err != nil { + return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain") + } + + if err := am.Store.SaveAccount(ctx, newAccount); err != nil { + log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err) + return nil, err + } + + am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil) + return newAccount, nil + } + + return nil, status.Errorf(status.Internal, "failed to create new account by private domain") +} + +func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { + account, err := am.Store.GetAccount(ctx, accountId) + if err != nil { + return nil, err + } + + if account.IsDomainPrimaryAccount { + return account, nil + } + + // additional check to ensure there is only one account for this domain at the time of update + count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain) + if err != nil { + return nil, err + } + + if count > 1 { + return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain") + } + + account.IsDomainPrimaryAccount = true + + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err) + return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id) + } + + return account, nil +} diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 37c50267b..807d05067 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -111,4 +111,7 @@ type Manager interface { BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error GetStore() store.Store + CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) + UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) + GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 715cfab84..49a7464e3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -17,6 +17,7 @@ import ( nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/util" @@ -2815,6 +2816,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) { return nil, err } + permissionsManagerMock := permissions.NewManagerMock() + ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) @@ -2828,7 +2831,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) { Return(false, nil). AnyTimes() - manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { return nil, err } @@ -3150,3 +3153,51 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { }) } } + +func Test_CreateAccountByPrivateDomain(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + ctx := context.Background() + initiatorId := "test-user" + domain := "example.com" + + account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.NoError(t, err) + + assert.False(t, account.IsDomainPrimaryAccount) + assert.Equal(t, domain, account.Domain) + assert.Equal(t, types.PrivateCategory, account.DomainCategory) + assert.Equal(t, initiatorId, account.CreatedBy) + assert.Equal(t, 1, len(account.Groups)) + assert.Equal(t, 0, len(account.Users)) + assert.Equal(t, 0, len(account.SetupKeys)) + + // retry should fail + _, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.Error(t, err) +} + +func Test_UpdateToPrimaryAccount(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + ctx := context.Background() + initiatorId := "test-user" + domain := "example.com" + + account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.NoError(t, err) + assert.False(t, account.IsDomainPrimaryAccount) + + // retry should fail + account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) + assert.NoError(t, err) + assert.True(t, account.IsDomainPrimaryAccount) +} diff --git a/management/server/dns.go b/management/server/dns.go index 39dc11eb2..8dcc59413 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -67,8 +67,8 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -89,8 +89,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } if !user.HasAdminPower() { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 824557356..aeccc6187 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -13,6 +13,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -210,13 +211,14 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + permissionsManagerMock := permissions.NewManagerMock() ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) } func createDNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/event.go b/management/server/event.go index 788d1b51c..58c6c70fb 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -10,6 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) func isEnabled() bool { @@ -19,16 +21,12 @@ func isEnabled() bool { // GetEvents returns a list of activity events of an account func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { return nil, err } @@ -58,6 +56,11 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI filtered = append(filtered, event) } + err = am.fillEventsWithUserInfo(ctx, events, accountID, user) + if err != nil { + return nil, err + } + return filtered, nil } @@ -79,3 +82,156 @@ func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, ta }() } } + +type eventUserInfo struct { + email string + name string + accountId string +} + +func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) error { + eventUserInfo, err := am.getEventsUserInfo(ctx, events, accountId, user) + if err != nil { + return err + } + + for _, event := range events { + if !fillEventInitiatorInfo(eventUserInfo, event) { + log.WithContext(ctx).Warnf("failed to resolve user info for initiator: %s", event.InitiatorID) + } + + fillEventTargetInfo(eventUserInfo, event) + } + return nil +} + +func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, user *types.User) (map[string]eventUserInfo, error) { + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId) + if err != nil { + return nil, err + } + + // @note check whether using a external initiator user here is an issue + userInfos, err := am.BuildUserInfosForAccount(ctx, accountId, user.Id, accountUsers) + if err != nil { + return nil, err + } + + eventUserInfos := make(map[string]eventUserInfo) + for i, k := range userInfos { + eventUserInfos[i] = eventUserInfo{ + email: k.Email, + name: k.Name, + accountId: accountId, + } + } + + externalUserIds := []string{} + for _, event := range events { + if _, ok := eventUserInfos[event.InitiatorID]; ok { + continue + } + + if event.InitiatorID == activity.SystemInitiator || + event.InitiatorID == accountId || + event.Activity == activity.PeerAddedWithSetupKey { + // @todo other events to be excluded if never initiated by a user + continue + } + + externalUserIds = append(externalUserIds, event.InitiatorID) + } + + if len(externalUserIds) == 0 { + return eventUserInfos, nil + } + + return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, user) +} + +func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, user *types.User) (map[string]eventUserInfo, error) { + externalAccountId := "" + fetched := make(map[string]struct{}) + externalUsers := []*types.User{} + for _, id := range externalUserIds { + if _, ok := fetched[id]; ok { + continue + } + + externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) + if err != nil { + // @todo consider logging + continue + } + + if externalAccountId != "" && externalAccountId != externalUser.AccountID { + return nil, fmt.Errorf("multiple external user accounts in events") + } + + if externalAccountId == "" { + externalAccountId = externalUser.AccountID + } + + fetched[id] = struct{}{} + externalUsers = append(externalUsers, externalUser) + } + + // if we couldn't determine an account, return what we have + if externalAccountId == "" { + log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds) + return eventUserInfos, nil + } + + externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, user.Id, externalUsers) + if err != nil { + return nil, err + } + + for i, k := range externalUserInfos { + eventUserInfos[i] = eventUserInfo{ + email: k.Email, + name: k.Name, + accountId: externalAccountId, + } + } + + return eventUserInfos, nil +} + +func fillEventTargetInfo(eventUserInfo map[string]eventUserInfo, event *activity.Event) { + userInfo, ok := eventUserInfo[event.TargetID] + if !ok { + return + } + + if event.Meta == nil { + event.Meta = make(map[string]any) + } + + event.Meta["email"] = userInfo.email + event.Meta["username"] = userInfo.name +} + +func fillEventInitiatorInfo(eventUserInfo map[string]eventUserInfo, event *activity.Event) bool { + userInfo, ok := eventUserInfo[event.InitiatorID] + if !ok { + return false + } + + if event.InitiatorEmail == "" { + event.InitiatorEmail = userInfo.email + } + + if event.InitiatorName == "" { + event.InitiatorName = userInfo.name + } + + if event.AccountID != userInfo.accountId { + if event.Meta == nil { + event.Meta = make(map[string]any) + } + + event.Meta["external"] = true + } + return true +} diff --git a/management/server/group.go b/management/server/group.go index 69140bc00..01ebb457c 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -35,8 +35,8 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } if user.IsRegularUser() { @@ -83,8 +83,8 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } if user.IsRegularUser() { @@ -215,8 +215,8 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } if user.IsRegularUser() { diff --git a/management/server/group_test.go b/management/server/group_test.go index 8cdef1dd8..dffaa80e3 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/management-integrations/integrations" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/networks" @@ -20,10 +19,8 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) @@ -691,10 +688,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { // Saving a group linked to network router should update account peers and send peer update t.Run("saving group linked to network router", func(t *testing.T) { - userManager := users.NewManager(manager.Store) - extraSettingsManager := integrations.NewManager(nil) - settingsManager := settings.NewManager(manager.Store, userManager, extraSettingsManager) - permissionsManager := permissions.NewManager(userManager, settingsManager) + permissionsManager := permissions.NewManager(manager.Store) groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager) routersManager := routers.NewManager(manager.Store, permissionsManager, manager) diff --git a/management/server/http/handlers/events/events_handler.go b/management/server/http/handlers/events/events_handler.go index 7ebdef78f..eee5d8aa7 100644 --- a/management/server/http/handlers/events/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -1,7 +1,6 @@ package events import ( - "context" "fmt" "net/http" @@ -47,66 +46,15 @@ func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } + events := make([]*api.Event, len(accountEvents)) for i, e := range accountEvents { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - util.WriteJSONObject(r.Context(), w, events) } -func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { - // build email, name maps based on users - userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) - if err != nil { - log.WithContext(ctx).Errorf("failed to get users from account: %s", err) - return err - } - - emails := make(map[string]string) - names := make(map[string]string) - for _, ui := range userInfos { - emails[ui.ID] = ui.Email - names[ui.ID] = ui.Name - } - - var ok bool - for _, event := range events { - // fill initiator - if event.InitiatorEmail == "" { - event.InitiatorEmail, ok = emails[event.InitiatorId] - if !ok { - log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId) - } - } - - if event.InitiatorName == "" { - // here to allowed to be empty because in the first release we did not store the name - event.InitiatorName = names[event.InitiatorId] - } - - // fill target meta - email, ok := emails[event.TargetId] - if !ok { - continue - } - event.Meta["email"] = email - - username, ok := names[event.TargetId] - if !ok { - continue - } - event.Meta["username"] = username - } - return nil -} - func toEventResponse(event *activity.Event) *api.Event { meta := make(map[string]string) if event.Meta != nil { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 9342d84a3..ae7255e5f 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -250,7 +250,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - user, err := account.FindUser(userID) + user, err := h.accountManager.GetUserByID(r.Context(), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -258,7 +258,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { // If the user is regular user and does not own the peer // with the given peerID return an empty list - if !user.HasAdminPower() && !user.IsServiceUser { + if !user.HasAdminPower() && !user.IsServiceUser && !userAuth.IsChild { peer, ok := account.Peers[peerID] if !ok { util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index cb60ae4f1..a03c3c29d 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -122,6 +122,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { } return p, nil }, + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { + switch id { + case adminUser: + return account.Users[adminUser], nil + case regularUser: + return account.Users[regularUser], nil + case serviceUser: + return account.Users[serviceUser], nil + default: + return nil, fmt.Errorf("user not found") + } + }, GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { return peers, nil }, diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 0f1c37eb7..ea731d9d8 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -301,7 +301,7 @@ func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) + util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index 01c4adcf3..31ea06460 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -16,19 +16,18 @@ import ( "github.com/golang-jwt/jwt" - "github.com/netbirdio/management-integrations/integrations" - - "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/users" - "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -124,8 +123,9 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve validatorMock := server.MocIntegratedValidator{} proxyController := integrations.NewController(store) userManager := users.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{})) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager) + permissionsManagerMock := permissions.NewManagerMock() + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManagerMock) + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManagerMock) if err != nil { t.Fatalf("Failed to create manager: %v", err) } @@ -143,7 +143,6 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve resourcesManagerMock := resources.NewManagerMock() routersManagerMock := routers.NewManagerMock() groupsManagerMock := groups.NewManagerMock() - permissionsManagerMock := permissions.NewManagerMock() peersManager := peers.NewManager(store, permissionsManagerMock) apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManagerMock, peersManager, settingsManager) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index d4933dd94..c87fe05ce 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -25,6 +25,7 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -431,6 +432,8 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + permissionsManagerMock := permissions.NewManagerMock() + ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) @@ -441,7 +444,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config) Return(&types.Settings{}, nil) accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { cleanup() diff --git a/management/server/management_test.go b/management/server/management_test.go index 689a05623..dd987c005 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -194,6 +195,7 @@ func startServer( Return(&types.Settings{}, nil). AnyTimes() + permissionsManagerMock := permissions.NewManagerMock() accountManager, err := server.BuildManager( context.Background(), str, @@ -208,6 +210,7 @@ func startServer( metrics, port_forwarding.NewControllerMock(), settingsMockManager, + permissionsManagerMock, ) if err != nil { t.Fatalf("failed creating an account manager: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index cb8d598f8..008a7059f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -112,6 +112,9 @@ type MockAccountManager struct { DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) GetStoreFunc func() store.Store + CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error) + UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) + GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { @@ -847,3 +850,24 @@ func (am *MockAccountManager) GetStore() store.Store { } return nil } + +func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) { + if am.CreateAccountByPrivateDomainFunc != nil { + return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented") +} + +func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { + if am.UpdateToPrimaryAccountFunc != nil { + return am.UpdateToPrimaryAccountFunc(ctx, accountId) + } + return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented") +} + +func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) { + if am.GetOwnerInfoFunc != nil { + return am.GetOwnerInfoFunc(ctx, accountId) + } + return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 1a01c7a89..b1cf2bc72 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -25,8 +25,8 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -46,8 +46,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } newNSGroup := &nbdns.NameServerGroup{ @@ -108,8 +108,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } var updateAccountPeers bool @@ -159,8 +159,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } var nsGroup *nbdns.NameServerGroup @@ -203,8 +203,8 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 9b260d237..13039ae63 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -774,11 +775,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + permissionsManagerMock := permissions.NewManagerMock() ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) } func createNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/peer.go b/management/server/peer.go index 4e70fe6e3..e7d4b29f5 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -37,8 +37,8 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) @@ -188,8 +188,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } var peer *nbpeer.Peer @@ -321,8 +321,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } } @@ -621,7 +621,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { - return fmt.Errorf("failed to update user last login: %w", err) + log.WithContext(ctx).Debugf("failed to update user last login: %v", err) } } else { err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) @@ -1054,7 +1054,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin()) if err != nil { - return err + log.WithContext(ctx).Debugf("failed to update user last login: %v", err) } am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) @@ -1099,8 +1099,8 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0b91ff37d..b2563dcb0 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -20,9 +20,10 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/util" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -1264,7 +1265,8 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + permissionsManagerMock := permissions.NewManagerMock() + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1332,7 +1334,8 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + permissionsManagerMock := permissions.NewManagerMock() + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1403,7 +1406,8 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + permissionsManagerMock := permissions.NewManagerMock() + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 0345405fe..24ac09d1a 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -5,10 +5,9 @@ import ( "errors" "fmt" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" ) type Module string @@ -17,6 +16,8 @@ const ( Networks Module = "networks" Peers Module = "peers" Groups Module = "groups" + Settings Module = "settings" + Accounts Module = "accounts" ) type Operation string @@ -28,42 +29,50 @@ const ( type Manager interface { ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) + ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error } type managerImpl struct { - userManager users.Manager - settingsManager settings.Manager + store store.Store } type managerMock struct { } -func NewManager(userManager users.Manager, settingsManager settings.Manager) Manager { +func NewManager(store store.Store) Manager { return &managerImpl{ - userManager: userManager, - settingsManager: settingsManager, + store: store, } } func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { - user, err := m.userManager.GetUser(ctx, userID) + user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return false, err } if user == nil { - return false, errors.New("user not found") + return false, status.NewUserNotFoundError(userID) } - if user.AccountID != accountID { - return false, errors.New("user does not belong to account") + if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return false, err + } + + switch module { + case Accounts: + if operation == Write && user.Role != types.UserRoleOwner { + return false, nil + } + return true, nil + default: } switch user.Role { case types.UserRoleAdmin, types.UserRoleOwner: return true, nil case types.UserRoleUser: - return m.validateRegularUserPermissions(ctx, accountID, userID, module, operation) + return m.validateRegularUserPermissions(ctx, accountID, module, operation) case types.UserRoleBillingAdmin: return false, nil default: @@ -71,8 +80,8 @@ func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, us } } -func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { - settings, err := m.settingsManager.GetSettings(ctx, accountID, activity.SystemInitiator) +func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID string, module Module, operation Operation) (bool, error) { + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, fmt.Errorf("failed to get settings: %w", err) } @@ -91,13 +100,30 @@ func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accoun return false, nil } +func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + return nil +} + func NewManagerMock() Manager { return &managerMock{} } func (m *managerMock) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { - if userID == "allowedUser" { + switch userID { + case "a23efe53-63fb-11ec-90d6-0242ac120003", "allowedUser", "testingUser", "account_creator", "serviceUserID", "test_user": return true, nil + default: + return false, nil } - return false, nil +} + +func (m *managerMock) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { + // @note managers explicitly checked this, so should the mock + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + return nil } diff --git a/management/server/policy.go b/management/server/policy.go index d222bba8a..15111ba06 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -22,8 +22,8 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -43,8 +43,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -100,8 +100,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } if user.IsRegularUser() { @@ -148,8 +148,8 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 1690f8e33..859ae6332 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -22,8 +22,8 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if !user.HasAdminPower() { @@ -43,8 +43,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if !user.HasAdminPower() { @@ -99,8 +99,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } if !user.HasAdminPower() { @@ -141,8 +141,8 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if !user.HasAdminPower() { diff --git a/management/server/route.go b/management/server/route.go index 94663dc80..6198cb520 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -25,7 +25,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err + } + + if !user.IsAdminOrServiceUser() { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } @@ -119,6 +123,15 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err + } + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -236,6 +249,15 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + if err != nil { + return err + } + + if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err + } + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err @@ -310,6 +332,15 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + if err != nil { + return err + } + + if err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err + } + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err @@ -342,7 +373,11 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err + } + + if !user.IsAdminOrServiceUser() { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } diff --git a/management/server/route_test.go b/management/server/route_test.go index 473fbd862..c8776b9ff 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -21,6 +21,7 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -1259,6 +1260,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + permissionsManagerMock := permissions.NewManagerMock() ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) @@ -1281,7 +1283,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { AnyTimes(). Return(&types.ExtraSettings{}, nil) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) } func createRouterStore(t *testing.T) (store.Store, error) { diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 28a984875..2b3f4877b 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -25,13 +26,15 @@ type managerImpl struct { store store.Store extraSettingsManager extra_settings.Manager userManager users.Manager + permissionsManager permissions.Manager } -func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager) Manager { +func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager { return &managerImpl{ store: store, extraSettingsManager: extraSettingsManager, userManager: userManager, + permissionsManager: permissionsManager, } } @@ -41,13 +44,12 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { if userID != activity.SystemInitiator { - user, err := m.userManager.GetUser(ctx, userID) + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Settings, permissions.Read) if err != nil { - return nil, fmt.Errorf("get user: %w", err) + return nil, status.NewPermissionValidationError(err) } - - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + if !ok { + return nil, status.NewPermissionDeniedError() } } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index b0bdad4e5..8b73a7d1e 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -61,8 +61,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -118,8 +118,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -180,8 +180,8 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -198,8 +198,8 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, err } - if user.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return nil, err } if user.IsRegularUser() { @@ -226,8 +226,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return err } - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil { + return err } if user.IsRegularUser() { diff --git a/management/server/status/error.go b/management/server/status/error.go index adf7e060c..5ab6f4e9e 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -183,7 +183,7 @@ func NewPermissionDeniedError() error { } func NewPermissionValidationError(err error) error { - return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err) + return Errorf(PermissionDenied, "failed to validate user permissions: %s", err) } func NewResourceNotPartOfNetworkError(resourceID, networkID string) error { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 9bdf51bd9..aacb56ab8 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -586,6 +586,19 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return users, nil } +func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) { + var user types.User + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed") + } + return nil, status.Errorf(status.Internal, "failed to get account owner from the store") + } + + return &user, nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) @@ -2194,3 +2207,17 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength return &peer, nil } + +func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) { + var count int64 + result := s.db.Model(&types.Account{}). + Where("domain = ? AND domain_category = ?", + strings.ToLower(domain), types.PrivateCategory, + ).Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to count accounts by private domain %s: %s", domain, result.Error) + return 0, status.Errorf(status.Internal, "failed to count accounts by private domain") + } + + return count, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 1975f11b2..c13a8dfe6 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -69,10 +69,12 @@ type Store interface { DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error + CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) + GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error diff --git a/management/server/user.go b/management/server/user.go index 381879ae6..c446bd8ea 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -30,8 +30,8 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI return nil, err } - if initiatorUser.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return nil, err } if !initiatorUser.HasAdminPower() { @@ -93,8 +93,8 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - if initiatorUser.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return nil, err } inviterID := userID @@ -142,12 +142,21 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u // createNewIdpUser validates the invite and creates a new user in the IdP func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) { + inviter, err := am.GetUserByID(ctx, inviterID) + if err != nil { + return nil, fmt.Errorf("failed to get inviter user: %w", err) + } + // inviterUser is the one who is inviting the new user - inviterUser, err := am.lookupUserInCache(ctx, inviterID, accountID) + inviterUser, err := am.lookupUserInCache(ctx, inviterID, inviter.AccountID) if err != nil { return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID) } + if inviterUser == nil { + return nil, status.Errorf(status.NotFound, "inviter user with ID %s is empty", inviterID) + } + // check if the user is already registered with this email => reject user, err := am.lookupUserInCacheByEmail(ctx, invite.Email, accountID) if err != nil { @@ -188,7 +197,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin) if err != nil { - log.WithContext(ctx).Errorf("failed saving user last login: %v", err) + log.WithContext(ctx).Debugf("failed to update user last login: %v", err) } if newLogin { @@ -228,8 +237,8 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return err } - if initiatorUser.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return err } if !initiatorUser.HasAdminPower() { @@ -290,8 +299,8 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin return err } - if initiatorUser.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return err } // check if the user is already registered with this ID @@ -338,8 +347,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, err } - if initiatorUser.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return nil, err } targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) @@ -376,8 +385,8 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string return err } - if initiatorUser.AccountID != accountID { - return status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return err } if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { @@ -411,8 +420,8 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, err } - if initiatorUser.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return nil, err } if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { @@ -429,8 +438,8 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, err } - if initiatorUser.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return nil, err } if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { @@ -476,8 +485,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, err } - if initiatorUser.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return nil, err } if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() { @@ -511,7 +520,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( - ctx, transaction, groupsMap, initiatorUser, update, addIfNotExists, settings, + ctx, transaction, groupsMap, accountID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process user update: %w", err) @@ -597,13 +606,13 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac } func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group, - initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { + accountID string, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { if update == nil { return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, update, addIfNotExists) + oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) if err != nil { return false, nil, nil, nil, err } @@ -614,7 +623,6 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact // only auto groups, revoked status, and integration reference can be updated for now updatedUser := oldUser.Copy() - updatedUser.AccountID = initiatorUser.AccountID updatedUser.Role = update.Role updatedUser.Blocked = update.Blocked updatedUser.AutoGroups = update.AutoGroups @@ -657,17 +665,23 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. -func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) { +func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if !addIfNotExists { return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) } + update.AccountID = accountID return update, nil // use all fields from update if addIfNotExists is true } return nil, err } + + if existingUser.AccountID != accountID { + return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch") + } + return existingUser, nil } @@ -705,6 +719,7 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.Us // validateUserUpdate validates the update operation for a user. func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error { + // @todo double check these if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } @@ -790,8 +805,8 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun return nil, err } - if initiatorUser.AccountID != accountID { - return nil, status.NewUserNotPartOfAccountError() + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return nil, err } return am.BuildUserInfosForAccount(ctx, accountID, initiatorUserID, accountUsers) @@ -967,6 +982,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return err } + if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, initiatorUser, false); err != nil { + return err + } + if !initiatorUser.HasAdminPower() { return status.NewAdminPermissionError() } @@ -1081,6 +1100,25 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return updateAccountPeers, nil } +// GetOwnerInfo retrieves the owner information for a given account ID. +func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) { + owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + if owner == nil { + return nil, status.Errorf(status.NotFound, "owner not found") + } + + userInfo, err := am.getUserInfo(ctx, owner, accountID) + if err != nil { + return nil, err + } + + return userInfo, nil +} + // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { diff --git a/management/server/user_test.go b/management/server/user_test.go index 13df2694f..d3344738b 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -8,11 +8,11 @@ import ( "time" "github.com/google/go-cmp/cmp" - "golang.org/x/exp/maps" nbcache "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/util" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -59,9 +59,11 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: s, - eventStore: &activity.InMemoryEventStore{}, + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) @@ -107,9 +109,11 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) @@ -133,9 +137,11 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) @@ -160,9 +166,11 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) @@ -183,9 +191,11 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) @@ -214,9 +224,11 @@ func TestUser_DeletePAT(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) @@ -255,9 +267,11 @@ func TestUser_GetPAT(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) @@ -296,9 +310,11 @@ func TestUser_GetAllPATs(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID) @@ -390,9 +406,11 @@ func TestUser_CreateServiceUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"}) @@ -435,9 +453,11 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ @@ -481,9 +501,11 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ @@ -510,10 +532,12 @@ func TestUser_InviteNewUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - cacheLoading: map[string]chan struct{}{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + cacheLoading: map[string]chan struct{}{}, + permissionsManager: permissionsMananagerMock, } cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) @@ -616,9 +640,11 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID) @@ -652,9 +678,11 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID) @@ -704,10 +732,12 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, integratedPeerValidator: MocIntegratedValidator{}, + permissionsManager: permissionsMananagerMock, } testCases := []struct { @@ -812,10 +842,12 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, integratedPeerValidator: MocIntegratedValidator{}, + permissionsManager: permissionsMananagerMock, } testCases := []struct { @@ -921,9 +953,11 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } claims := nbcontext.UserAuth{ @@ -957,9 +991,11 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } users, err := am.ListUsers(context.Background(), mockAccountID) @@ -1044,9 +1080,11 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } users, err := am.ListUsers(context.Background(), mockAccountID) @@ -1087,11 +1125,13 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - idpManager: &idp.GoogleWorkspaceManager{}, // empty manager - cacheLoading: map[string]chan struct{}{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + idpManager: &idp.GoogleWorkspaceManager{}, // empty manager + cacheLoading: map[string]chan struct{}{}, + permissionsManager: permissionsMananagerMock, } cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) @@ -1148,9 +1188,11 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) @@ -1180,9 +1222,11 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + permissionsMananagerMock := permissions.NewManagerMock() am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsMananagerMock, } users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID) @@ -1525,3 +1569,41 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) } + +func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) { + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + + account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "") + targetId := "user2" + account1.Users[targetId] = &types.User{ + Id: targetId, + AccountID: account1.Id, + ServiceUserName: "user2username", + } + require.NoError(t, s.SaveAccount(context.Background(), account1)) + + account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "") + require.NoError(t, s.SaveAccount(context.Background(), account2)) + + permissionsManagerMock := permissions.NewManagerMock() + am := DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + idpManager: nil, + cacheLoading: map[string]chan struct{}{}, + permissionsManager: permissionsManagerMock, + } + + _, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true) + assert.Error(t, err, "update user to another account should fail") + + user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId) + require.NoError(t, err) + assert.Equal(t, account1.Users[targetId].Id, user.Id) + assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID) + assert.Equal(t, account1.Users[targetId].AutoGroups, user.AutoGroups) +}