diff --git a/management/client/rest/dns_test.go b/management/client/rest/dns_test.go index 0d57d63d7..b2e0a0bee 100644 --- a/management/client/rest/dns_test.go +++ b/management/client/rest/dns_test.go @@ -260,6 +260,7 @@ func TestDNS_Integration(t *testing.T) { nsGroupReq := api.NameserverGroupRequest{ Description: "Test", Enabled: true, + Domains: []string{}, Groups: []string{"cs1tnh0hhcjnqoiuebeg"}, Name: "test", Nameservers: []api.Nameserver{ diff --git a/management/server/account.go b/management/server/account.go index 2c62a2453..a0c6fd0b0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -67,7 +67,7 @@ type AccountManager interface { SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error - DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error + DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) @@ -79,7 +79,7 @@ type AccountManager interface { GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) + GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) @@ -96,7 +96,7 @@ type AccountManager interface { DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) - GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) @@ -149,6 +149,7 @@ type AccountManager interface { GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error UpdateAccountPeers(ctx context.Context, accountID string) + BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) } type DefaultAccountManager struct { @@ -617,6 +618,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u if user.Role != types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } + + userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users)) + if err != nil { + return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) + } + for _, otherUser := range account.Users { if otherUser.IsServiceUser { continue @@ -626,13 +633,23 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u continue } - deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id) + userInfo, ok := userInfosMap[otherUser.Id] + if !ok { + return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id) + } + + _, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo) if deleteUserErr != nil { return deleteUserErr } } - err = am.deleteRegularUser(ctx, account, userID, userID) + userInfo, ok := userInfosMap[userID] + if !ok { + return status.Errorf(status.NotFound, "user info not found for user %s", userID) + } + + _, err = am.deleteRegularUser(ctx, accountID, userID, userInfo) if err != nil { log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) return err @@ -689,20 +706,8 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return err - } - cachedAccount := &types.Account{ - Id: accountID, - Users: make(map[string]*types.User), - } - for _, user := range accountUsers { - cachedAccount.Users[user.Id] = user - } - // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, cachedAccount) + user, err := am.lookupUserInCache(ctx, userID, accountID) if err != nil { return err } @@ -778,10 +783,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) { - users := make(map[string]userLoggedInOnce, len(account.Users)) +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + users := make(map[string]userLoggedInOnce, len(accountUsers)) // ignore service users and users provisioned by integrations than are never logged in - for _, user := range account.Users { + for _, user := range accountUsers { if user.IsServiceUser { continue } @@ -790,8 +800,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s } users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero()) } - log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id) - userData, err := am.lookupCache(ctx, users, account.Id) + log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID) + userData, err := am.lookupCache(ctx, users, accountID) if err != nil { return nil, err } @@ -804,13 +814,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { - log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id) + log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) return nil, err } - key := user.IntegrationReference.CacheKey(account.Id, userID) + key := user.IntegrationReference.CacheKey(accountID, userID) ud, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err) @@ -1050,9 +1060,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - usersMap := make(map[string]*types.User) - usersMap[claims.UserId] = types.NewRegularUser(claims.UserId) - err := am.Store.SaveUsers(domainAccountID, usersMap) + newUser := types.NewRegularUser(claims.UserId) + newUser.AccountID = domainAccountID + err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) if err != nil { return "", err } @@ -1075,12 +1085,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str return nil } - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := am.lookupUserInCache(ctx, userID, account) + user, err := am.lookupUserInCache(ctx, userID, accountID) if err != nil { return err } @@ -1090,17 +1095,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str } if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { - log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id) + log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID) // User has already logged in, meaning that IdP should have set wt_pending_invite to false. // Our job is to just reload cache. go func() { - _, err = am.refreshCache(ctx, account.Id) + _, err = am.refreshCache(ctx, accountID) if err != nil { - log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) + log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID) return } - log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id) - am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil) + log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID) + am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil) }() } @@ -1109,33 +1114,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - - user, err := am.Store.GetUserByTokenID(ctx, tokenID) - if err != nil { - return err - } - - account, err := am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - account, err = am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return err - } - - pat, ok := account.Users[user.Id].PATs[tokenID] - if !ok { - return fmt.Errorf("token not found") - } - - pat.LastUsed = util.ToPtr(time.Now().UTC()) - - return am.Store.SaveAccount(ctx, account) + return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) } // GetAccount returns an account associated with this account ID. @@ -1143,52 +1122,64 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin return am.Store.GetAccount(ctx, accountID) } -// GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { +// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. +func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { + user, pat, err = am.extractPATFromToken(ctx, token) + if err != nil { + return nil, nil, "", "", err + } + + domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) + if err != nil { + return nil, nil, "", "", err + } + + return user, pat, domain, category, nil +} + +// extractPATFromToken validates the token structure and retrieves associated User and PAT. +func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) { if len(token) != types.PATLength { - return nil, nil, nil, fmt.Errorf("token has wrong length") + return nil, nil, fmt.Errorf("token has incorrect length") } prefix := token[:len(types.PATPrefix)] if prefix != types.PATPrefix { - return nil, nil, nil, fmt.Errorf("token has wrong prefix") + return nil, nil, fmt.Errorf("token has wrong prefix") } secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { - return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) + return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) } secretChecksum := crc32.ChecksumIEEE([]byte(secret)) if secretChecksum != verificationChecksum { - return nil, nil, nil, fmt.Errorf("token checksum does not match") + return nil, nil, fmt.Errorf("token checksum does not match") } hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken) + + var user *types.User + var pat *types.PersonalAccessToken + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) + if err != nil { + return err + } + + user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) + return err + }) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - user, err := am.Store.GetUserByTokenID(ctx, tokenID) - if err != nil { - return nil, nil, nil, err - } - - account, err := am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return nil, nil, nil, err - } - - pat := user.PATs[tokenID] - if pat == nil { - return nil, nil, nil, fmt.Errorf("personal access token not found") - } - - return account, user, pat, nil + return user, pat, nil } // GetAccountByID returns an account associated with this account ID. @@ -1334,7 +1325,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting user peers: %w", err) } - updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) + updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) if err != nil { return fmt.Errorf("error modifying user peers in groups: %w", err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 1fc1ceb92..0a7f9119b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -732,6 +732,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", + UserID: "someUser", HashedToken: encodedHashedToken, }, }, @@ -745,14 +746,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { Store: store, } - account, user, pat, err := am.GetAccountFromPAT(context.Background(), token) + user, pat, _, _, err := am.GetPATInfo(context.Background(), token) if err != nil { t.Fatalf("Error when getting Account from PAT: %s", err) } - assert.Equal(t, "account_id", account.Id) + assert.Equal(t, "account_id", user.AccountID) assert.Equal(t, "someUser", user.Id) - assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) + assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID) } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index 823e0b4ac..ffb863de9 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "path/filepath" + "runtime" "time" _ "github.com/mattn/go-sqlite3" @@ -95,6 +96,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) ( if err != nil { return nil, err } + db.SetMaxOpenConns(runtime.NumCPU()) crypt, err := NewFieldEncrypt(encryptionKey) if err != nil { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index cc2ad00b7..7ce09fffa 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -43,7 +43,7 @@ func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, network ) authMiddleware := middleware.NewAuthMiddleware( - accountManager.GetAccountFromPAT, + accountManager.GetPATInfo, jwtValidator.ValidateAndParse, accountManager.MarkPATUsed, accountManager.CheckUserAccessByJWTGroups, diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index 17478aba3..fd603f289 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { - return make([]*types.UserInfo, 0), nil + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) { + return make(map[string]*types.UserInfo), nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index 90081830a..ff77cedff 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -52,7 +52,7 @@ var usersTestAccount = &types.Account{ Issued: types.UserIssuedAPI, }, nonDeletableServiceUserID: { - Id: serviceUserID, + Id: nonDeletableServiceUserID, Role: "admin", IsServiceUser: true, NonDeletable: true, @@ -70,10 +70,10 @@ func initUsersTestData() *handler { GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { return usersTestAccount.Users[id], nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { - users := make([]*types.UserInfo, 0) + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) { + usersInfos := make(map[string]*types.UserInfo) for _, v := range usersTestAccount.Users { - users = append(users, &types.UserInfo{ + usersInfos[v.Id] = &types.UserInfo{ ID: v.Id, Role: string(v.Role), Name: "", @@ -81,9 +81,9 @@ func initUsersTestData() *handler { IsServiceUser: v.IsServiceUser, NonDeletable: v.NonDeletable, Issued: v.Issued, - }) + } } - return users, nil + return usersInfos, nil }, CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) { if userID != existingUserID { diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 182c30cf6..dcf73259a 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -19,8 +19,8 @@ import ( "github.com/netbirdio/netbird/management/server/types" ) -// GetAccountFromPATFunc function -type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) +// GetAccountInfoFromPATFunc function +type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) // ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) @@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - getAccountFromPAT GetAccountFromPATFunc + getAccountInfoFromPAT GetAccountInfoFromPATFunc validateAndParseToken ValidateAndParseTokenFunc markPATUsed MarkPATUsedFunc checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc @@ -47,7 +47,7 @@ const ( ) // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, +func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, audience string, userIdClaim string) *AuthMiddleware { if userIdClaim == "" { @@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse } return &AuthMiddleware{ - getAccountFromPAT: getAccountFromPAT, + getAccountInfoFromPAT: getAccountInfoFromPAT, validateAndParseToken: validateAndParseToken, markPATUsed: markPATUsed, checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, @@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j // CheckPATFromRequest checks if the PAT is valid func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { token, err := getTokenFromPATRequest(auth) - - // If an error occurs, call the error handler and return an error if err != nil { - return fmt.Errorf("Error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } - account, user, pat, err := m.getAccountFromPAT(r.Context(), token) + user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token) if err != nil { return fmt.Errorf("invalid Token: %w", err) } @@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps := jwt.MapClaims{} claimMaps[m.userIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory claimMaps[jwtclaims.IsToken] = true jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 41bdb7fc5..c1686ed44 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -34,7 +34,8 @@ var testAccount = &types.Account{ Domain: domain, Users: map[string]*types.User{ userID: { - Id: userID, + Id: userID, + AccountID: accountID, PATs: map[string]*types.PersonalAccessToken{ tokenID: { ID: tokenID, @@ -50,11 +51,11 @@ var testAccount = &types.Account{ }, } -func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { +func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) { if token == PAT { - return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } - return nil, nil, nil, fmt.Errorf("PAT invalid") + return nil, nil, "", "", fmt.Errorf("PAT invalid") } func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { @@ -166,7 +167,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { ) authMiddleware := NewAuthMiddleware( - mockGetAccountFromPAT, + mockGetAccountInfoFromPAT, mockValidateAndParseToken, mockMarkPATUsed, mockCheckUserAccessByJWTGroups, diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 549a51c0e..0baf76328 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdateUser(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000}, - "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50}, - "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250}, - "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700}, - "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400}, - "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000}, - "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000}, - "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, + "Users - XS": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 310}, + "Users - S": {MinMsPerOpLocal: 0.3, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15}, + "Users - M": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 3, MaxMsPerOpCICD: 20}, + "Users - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50}, + "Peers - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 310}, + "Groups - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 120}, + "Setup Keys - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50}, + "Users - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 280}, } log.SetOutput(io.Discard) @@ -118,14 +118,14 @@ func BenchmarkGetOneUser(b *testing.B) { func BenchmarkGetAllUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180}, - "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, - "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, - "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, - "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200}, - "Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90}, + "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10}, + "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10}, + "Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15}, + "Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50}, + "Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55}, + "Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55}, + "Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55}, + "Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, } log.SetOutput(io.Discard) @@ -141,7 +141,7 @@ func BenchmarkGetAllUsers(b *testing.B) { b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId) + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } @@ -152,14 +152,14 @@ func BenchmarkGetAllUsers(b *testing.B) { func BenchmarkDeleteUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000}, - "Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, - "Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230}, - "Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190}, - "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800}, - "Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500}, - "Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600}, - "Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400}, + "Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, } log.SetOutput(io.Discard) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index c8e42d20a..b20eb87bb 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -53,8 +53,8 @@ type MockAccountManager struct { SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) - GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) + GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -69,7 +69,7 @@ type MockAccountManager struct { SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error + DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) @@ -110,6 +110,7 @@ type MockAccountManager struct { GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error + BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { @@ -165,7 +166,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) (map[string]*types.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { return am.GetUsersFromAccountFunc(ctx, accountID, userID) } @@ -238,12 +239,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { - if am.GetAccountFromPATFunc != nil { - return am.GetAccountFromPATFunc(ctx, pat) +// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface +func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) { + if am.GetPATInfoFunc != nil { + return am.GetPATInfoFunc(ctx, pat) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") + return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented") } // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface @@ -550,9 +551,9 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, } // DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface -func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error { +func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error { if am.DeleteRegularUsersFunc != nil { - return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs) + return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs, userInfos) } return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented") } @@ -849,3 +850,11 @@ func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peer } return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented") } + +// BuildUserInfosForAccount mocks BuildUserInfosForAccount of the AccountManager interface +func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) { + if am.BuildUserInfosForAccountFunc != nil { + return am.BuildUserInfosForAccountFunc(ctx, accountID, initiatorUserID, accountUsers) + } + return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented") +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index a0417c996..6894d092d 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -28,7 +29,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" - nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -1554,7 +1554,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ - Enabled: true, + AccountID: account.Id, + Enabled: true, Rules: []*types.PolicyRule{ { Enabled: true, diff --git a/management/server/status/error.go b/management/server/status/error.go index 7e384922d..96b103183 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -93,7 +93,7 @@ func NewPeerNotPartOfAccountError() error { // NewUserNotFoundError creates a new Error with NotFound type for a missing user func NewUserNotFoundError(userKey string) error { - return Errorf(NotFound, "user not found: %s", userKey) + return Errorf(NotFound, "user: %s not found", userKey) } // NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer @@ -191,3 +191,18 @@ func NewResourceNotPartOfNetworkError(resourceID, networkID string) error { func NewRouterNotPartOfNetworkError(routerID, networkID string) error { return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID) } + +// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role +func NewServiceUserRoleInvalidError() error { + return Errorf(InvalidArgument, "can't create a service user with owner role") +} + +// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting +// to delete a user with the owner role. +func NewOwnerDeletePermissionError() error { + return Errorf(PermissionDenied, "can't delete a user with the owner role") +} + +func NewPATNotFoundError(patID string) error { + return Errorf(NotFound, "PAT: %s not found", patID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 2179f0754..6a6753595 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/management/server/util" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -414,24 +415,16 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr } // SaveUsers saves the given list of users to the database. -// It updates existing users if a conflict occurs. -func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error { - usersToSave := make([]types.User, 0, len(users)) - for _, user := range users { - user.AccountID = accountID - for id, pat := range user.PATs { - pat.ID = id - user.PATsG = append(user.PATsG, *pat) - } - usersToSave = append(usersToSave, *user) - } - err := s.db.Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.OnConflict{UpdateAll: true}). - Create(&usersToSave).Error - if err != nil { - return status.Errorf(status.Internal, "failed to save users to store: %v", err) +func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error { + if len(users) == 0 { + return nil } + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save users to store") + } return nil } @@ -439,7 +432,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) err func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save user to store") } return nil } @@ -526,30 +520,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return token.ID, nil } -func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) { - var token types.PersonalAccessToken - result := s.db.First(&token, idQueryCondition, tokenID) +func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) { + var user types.User + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). + Where("personal_access_tokens.id = ?", patID).First(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + return nil, status.NewPATNotFoundError(patID) } - log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return nil, status.NewGetAccountFromStoreError(result.Error) - } - - if token.UserID == "" { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") - } - - var user types.User - result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID) - if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") - } - - user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG)) - for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error) + return nil, status.NewGetUserFromStoreError() } return &user, nil @@ -557,8 +538,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { var user types.User - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Preload(clause.Associations).First(&user, idQueryCondition, userID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -569,6 +549,25 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } +func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error { + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) + if result.Error != nil { + return result.Error + } + + return tx.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete user from store") + } + + return nil +} + func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { var users []*types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) @@ -899,6 +898,20 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS return accountSettings.Settings, nil } +func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) { + var createdBy string + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + Select("created_by").First(&createdBy, idQueryCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.NewAccountNotFoundError(accountID) + } + return "", status.NewGetAccountFromStoreError(result.Error) + } + + return createdBy, nil +} + // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user types.User @@ -2053,3 +2066,94 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki return nil } + +// GetPATByHashedToken returns a PersonalAccessToken by its hashed token. +func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPATNotFoundError(hashedToken) + } + log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get pat by hash from store") + } + + return &pat, nil +} + +// GetPATByID retrieves a personal access token by its ID and user ID. +func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&pat, "id = ? AND user_id = ?", patID, userID) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPATNotFoundError(patID) + } + log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get pat from store") + } + + return &pat, nil +} + +// GetUserPATs retrieves personal access tokens for a user. +func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) { + var pats []*types.PersonalAccessToken + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get user pat's from store") + } + + return pats, nil +} + +// MarkPATUsed marks a personal access token as used. +func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error { + patCopy := types.PersonalAccessToken{ + LastUsed: util.ToPtr(time.Now().UTC()), + } + + fieldsToUpdate := []string{"last_used"} + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate). + Where(idQueryCondition, patID).Updates(&patCopy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error) + return status.Errorf(status.Internal, "failed to mark pat as used") + } + + if result.RowsAffected == 0 { + return status.NewPATNotFoundError(patID) + } + + return nil +} + +// SavePAT saves a personal access token to the database. +func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err) + return status.Errorf(status.Internal, "failed to save pat to store") + } + + return nil +} + +// DeletePAT deletes a personal access token from the database. +func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete pat from store") + } + + if result.RowsAffected == 0 { + return status.NewPATNotFoundError(patID) + } + + return nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 9350da1c8..4dcdadf44 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -627,29 +627,6 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } -func TestSqlite_GetUserByTokenID(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - - user, err := store.GetUserByTokenID(context.Background(), id) - require.NoError(t, err) - require.Equal(t, id, user.PATs[id].ID) - - _, err = store.GetUserByTokenID(context.Background(), "non-existing-id") - require.Error(t, err) - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") -} - func TestMigrate(t *testing.T) { if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { t.Skip("skip CI tests on darwin and windows") @@ -962,23 +939,6 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { require.Equal(t, id, token) } -func TestPostgresql_GetUserByTokenID(t *testing.T) { - if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { - t.Skip("skip CI tests on darwin and windows") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - - user, err := store.GetUserByTokenID(context.Background(), id) - require.NoError(t, err) - require.Equal(t, id, user.PATs[id].ID) -} - func TestSqlite_GetTakenIPs(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) @@ -1182,7 +1142,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { assert.NoError(t, err) } -func TestSqlite_GetAccoundUsers(t *testing.T) { +func TestSqlStore_GetAccountUsers(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { @@ -2915,3 +2875,326 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) { t.Logf("Test completed") } + +func TestSqlStore_GetAccountCreatedBy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectError bool + createdBy string + }{ + { + name: "existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectError: false, + createdBy: "edafee4e-63fb-11ec-90d6-0242ac120003", + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectError: true, + }, + { + name: "empty account ID", + accountID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthShare, tt.accountID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Empty(t, createdBy) + } else { + require.NoError(t, err) + require.NotNil(t, createdBy) + require.Equal(t, tt.createdBy, createdBy) + } + }) + } + +} + +func TestSqlStore_GetUserByUserID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + userID string + expectError bool + }{ + { + name: "retrieve existing user", + userID: "edafee4e-63fb-11ec-90d6-0242ac120003", + expectError: false, + }, + { + name: "retrieve non-existing user", + userID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty user ID", + userID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, tt.userID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, user) + } else { + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, tt.userID, user.Id) + } + }) + } +} + +func TestSqlStore_GetUserByPATID(t *testing.T) { + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + + user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id) + require.NoError(t, err) + require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) +} + +func TestSqlStore_SaveUser(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + user := &types.User{ + Id: "user-id", + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: false, + AutoGroups: []string{"groupA", "groupB"}, + Blocked: false, + LastLogin: util.ToPtr(time.Now().UTC()), + CreatedAt: time.Now().UTC().Add(-time.Hour), + Issued: types.UserIssuedIntegration, + } + err = store.SaveUser(context.Background(), LockingStrengthUpdate, user) + require.NoError(t, err) + + saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, user.Id) + require.NoError(t, err) + require.Equal(t, user.Id, saveUser.Id) + require.Equal(t, user.AccountID, saveUser.AccountID) + require.Equal(t, user.Role, saveUser.Role) + require.Equal(t, user.AutoGroups, saveUser.AutoGroups) + require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") + require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") + require.Equal(t, user.Issued, saveUser.Issued) + require.Equal(t, user.Blocked, saveUser.Blocked) + require.Equal(t, user.IsServiceUser, saveUser.IsServiceUser) +} + +func TestSqlStore_SaveUsers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Len(t, accountUsers, 2) + + users := []*types.User{ + { + Id: "user-1", + AccountID: accountID, + Issued: "api", + AutoGroups: []string{"groupA", "groupB"}, + }, + { + Id: "user-2", + AccountID: accountID, + Issued: "integration", + AutoGroups: []string{"groupA"}, + }, + } + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users) + require.NoError(t, err) + + accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Len(t, accountUsers, 4) +} + +func TestSqlStore_DeleteUser(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + userID := "f4f6d672-63fb-11ec-90d6-0242ac120003" + + err = store.DeleteUser(context.Background(), LockingStrengthUpdate, accountID, userID) + require.NoError(t, err) + + user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, userID) + require.Error(t, err) + require.Nil(t, user) + + userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, userID) + require.NoError(t, err) + require.Len(t, userPATs, 0) +} + +func TestSqlStore_GetPATByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + userID := "f4f6d672-63fb-11ec-90d6-0242ac120003" + + tests := []struct { + name string + patID string + expectError bool + }{ + { + name: "retrieve existing PAT", + patID: "9dj38s35-63fb-11ec-90d6-0242ac120003", + expectError: false, + }, + { + name: "retrieve non-existing PAT", + patID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty PAT ID", + patID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, tt.patID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, pat) + } else { + require.NoError(t, err) + require.NotNil(t, pat) + require.Equal(t, tt.patID, pat.ID) + } + }) + } +} + +func TestSqlStore_GetUserPATs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, "f4f6d672-63fb-11ec-90d6-0242ac120003") + require.NoError(t, err) + require.Len(t, userPATs, 1) +} + +func TestSqlStore_GetPATByHashedToken(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "SoMeHaShEdToKeN") + require.NoError(t, err) + require.Equal(t, "9dj38s35-63fb-11ec-90d6-0242ac120003", pat.ID) +} + +func TestSqlStore_MarkPATUsed(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + userID := "f4f6d672-63fb-11ec-90d6-0242ac120003" + patID := "9dj38s35-63fb-11ec-90d6-0242ac120003" + + err = store.MarkPATUsed(context.Background(), LockingStrengthUpdate, patID) + require.NoError(t, err) + + pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID) + require.NoError(t, err) + now := time.Now().UTC() + require.WithinRange(t, pat.LastUsed.UTC(), now.Add(-15*time.Second), now, "LastUsed should be within 1 second of now") +} + +func TestSqlStore_SavePAT(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + userID := "edafee4e-63fb-11ec-90d6-0242ac120003" + + pat := &types.PersonalAccessToken{ + ID: "pat-id", + UserID: userID, + Name: "token", + HashedToken: "SoMeHaShEdToKeN", + ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)), + CreatedBy: userID, + CreatedAt: time.Now().UTC().Add(time.Hour), + LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)), + } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat) + require.NoError(t, err) + + savePAT, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, pat.ID) + require.NoError(t, err) + require.Equal(t, pat.ID, savePAT.ID) + require.Equal(t, pat.UserID, savePAT.UserID) + require.Equal(t, pat.HashedToken, savePAT.HashedToken) + require.Equal(t, pat.CreatedBy, savePAT.CreatedBy) + require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal") + require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") + require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal") +} + +func TestSqlStore_DeletePAT(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + userID := "f4f6d672-63fb-11ec-90d6-0242ac120003" + patID := "9dj38s35-63fb-11ec-90d6-0242ac120003" + + err = store.DeletePAT(context.Background(), LockingStrengthUpdate, userID, patID) + require.NoError(t, err) + + pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID) + require.Error(t, err) + require.Nil(t, pat) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 4b4dcfb4f..6d3a409e6 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -59,21 +59,30 @@ type Store interface { GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) + GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) SaveAccount(ctx context.Context, account *types.Account) error DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error - GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) + GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) - SaveUsers(accountID string, users map[string]*types.User) error + SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error + GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) + GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) + GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) + MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error + SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error + DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 1c0767bde..41b8fa2f7 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -37,7 +37,7 @@ CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`); CREATE INDEX `idx_networks_id` ON `networks`(`id`); CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); -INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); diff --git a/management/server/types/personal_access_token.go b/management/server/types/personal_access_token.go index ff157fcc6..0aa6b152b 100644 --- a/management/server/types/personal_access_token.go +++ b/management/server/types/personal_access_token.go @@ -75,7 +75,7 @@ type PersonalAccessTokenGenerated struct { // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version -func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { +func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) { hashedToken, plainToken, err := generateNewToken() if err != nil { return nil, err @@ -84,6 +84,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona return &PersonalAccessTokenGenerated{ PersonalAccessToken: PersonalAccessToken{ ID: xid.New().String(), + UserID: targetID, Name: name, HashedToken: hashedToken, ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)), diff --git a/management/server/user.go b/management/server/user.go index 17770a423..6ba9b68d3 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,13 +4,10 @@ import ( "context" "errors" "fmt" - "slices" "strings" "time" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" @@ -20,6 +17,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + log "github.com/sirupsen/logrus" ) // createServiceUser creates a new service user under the given account. @@ -27,30 +25,29 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { - return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) + return nil, err } - executingUser := account.Users[initiatorUserID] - if executingUser == nil { - return nil, status.Errorf(status.NotFound, "user not found") + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if !executingUser.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users") + + if !initiatorUser.HasAdminPower() { + return nil, status.NewAdminPermissionError() } if role == types.UserRoleOwner { - return nil, status.Errorf(status.InvalidArgument, "can't create a service user with owner role") + return nil, status.NewServiceUserRoleInvalidError() } newUserID := uuid.New().String() newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI) + newUser.AccountID = accountID log.WithContext(ctx).Debugf("New User: %v", newUser) - account.Users[newUserID] = newUser - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { return nil, err } @@ -87,40 +84,67 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } - if invite == nil { - return nil, fmt.Errorf("provided user update is nil") + if err := validateUserInvite(invite); err != nil { + return nil, err } - invitedRole := types.StrRoleToUserRole(invite.Role) - - switch { - case invite.Name == "": - return nil, status.Errorf(status.InvalidArgument, "name can't be empty") - case invite.Email == "": - return nil, status.Errorf(status.InvalidArgument, "email can't be empty") - case invitedRole == types.UserRoleOwner: - return nil, status.Errorf(status.InvalidArgument, "can't invite a user with owner role") - default: - } - - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { - return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) + return nil, err } - initiatorUser, err := account.FindUser(userID) - if err != nil { - return nil, status.Errorf(status.NotFound, "initiator user with ID %s doesn't exist", userID) + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } inviterID := userID if initiatorUser.IsServiceUser { - inviterID = account.CreatedBy + createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + inviterID = createdBy } + idpUser, err := am.createNewIdpUser(ctx, accountID, inviterID, invite) + if err != nil { + return nil, err + } + + newUser := &types.User{ + Id: idpUser.ID, + AccountID: accountID, + Role: types.StrRoleToUserRole(invite.Role), + AutoGroups: invite.AutoGroups, + Issued: invite.Issued, + IntegrationReference: invite.IntegrationReference, + CreatedAt: time.Now().UTC(), + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { + return nil, err + } + + _, err = am.refreshCache(ctx, accountID) + if err != nil { + return nil, err + } + + am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) + + return newUser.ToUserInfo(idpUser, settings) +} + +// createNewIdpUser validates the invite and creates a new user in the IdP +func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) { // inviterUser is the one who is inviting the new user - inviterUser, err := am.lookupUserInCache(ctx, inviterID, account) - if err != nil || inviterUser == nil { + inviterUser, err := am.lookupUserInCache(ctx, inviterID, accountID) + if err != nil { return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID) } @@ -143,34 +167,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } - idpUser, err := am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email) - if err != nil { - return nil, err - } - - newUser := &types.User{ - Id: idpUser.ID, - Role: invitedRole, - AutoGroups: invite.AutoGroups, - Issued: invite.Issued, - IntegrationReference: invite.IntegrationReference, - CreatedAt: time.Now().UTC(), - } - account.Users[idpUser.ID] = newUser - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - - _, err = am.refreshCache(ctx, account.Id) - if err != nil { - return nil, err - } - - am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) - - return newUser.ToUserInfo(idpUser, account.Settings) + return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email) } func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { @@ -210,60 +207,51 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - users := make([]*types.User, 0, len(account.Users)) - for _, item := range account.Users { - users = append(users, item) - } - - return users, nil + return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *types.Account, initiatorUserID string, targetUser *types.User) { +func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error { + if err := am.Store.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUser.Id); err != nil { + return err + } meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} - am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) - delete(account.Users, targetUser.Id) + am.StoreEvent(ctx, initiatorUserID, targetUser.Id, accountID, activity.ServiceUserDeleted, meta) + return nil } // DeleteUser deletes a user from the given account. -func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error { +func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { if initiatorUserID == targetUserID { return status.Errorf(status.InvalidArgument, "self deletion is not allowed") } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return err } - executingUser := account.Users[initiatorUserID] - if executingUser == nil { - return status.Errorf(status.NotFound, "user not found") - } - if !executingUser.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power can delete users") + if initiatorUser.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - targetUser := account.Users[targetUserID] - if targetUser == nil { - return status.Errorf(status.NotFound, "target user not found") + if !initiatorUser.HasAdminPower() { + return status.NewAdminPermissionError() + } + + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + if err != nil { + return err } if targetUser.Role == types.UserRoleOwner { - return status.Errorf(status.PermissionDenied, "unable to delete a user with owner role") + return status.NewOwnerDeletePermissionError() } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !initiatorUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only integration service user can delete this user") } @@ -273,64 +261,26 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.Errorf(status.PermissionDenied, "service user is marked as non-deletable") } - am.deleteServiceUser(ctx, account, initiatorUserID, targetUser) - return am.Store.SaveAccount(ctx, account) + return am.deleteServiceUser(ctx, accountID, initiatorUserID, targetUser) } - return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID) -} - -func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) error { - meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + userInfo, err := am.getUserInfo(ctx, targetUser, accountID) if err != nil { return err } - delete(account.Users, targetUserID) - if updateAccountPeers { - account.Network.IncSerial() - } - - err = am.Store.SaveAccount(ctx, account) + updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { return err } - am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.UpdateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, accountID) } return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *types.Account) (bool, error) { - peers, err := account.FindUserPeers(targetUserID) - if err != nil { - return false, status.Errorf(status.Internal, "failed to find user peers") - } - - hadPeers := len(peers) > 0 - if !hadPeers { - return false, nil - } - - eventsToStore, err := deletePeers(ctx, am, am.Store, account.Id, initiatorUserID, peers) - if err != nil { - return false, err - } - - for _, storeEvent := range eventsToStore { - storeEvent() - } - - for _, peer := range peers { - account.DeletePeer(peer.ID) - } - - return hadPeers, nil -} - // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -340,13 +290,17 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { - return status.Errorf(status.NotFound, "account %s doesn't exist", accountID) + return err + } + + if initiatorUser.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } // check if the user is already registered with this ID - user, err := am.lookupUserInCache(ctx, targetUserID, account) + user, err := am.lookupUserInCache(ctx, targetUserID, accountID) if err != nil { return err } @@ -384,35 +338,31 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { - return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") + if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { + return nil, status.NewAdminPermissionError() } - pat, err := types.CreateNewPAT(tokenName, expiresIn, executingUser.Id) + pat, err := types.CreateNewPAT(tokenName, expiresIn, targetUserID, initiatorUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } - targetUser.PATs[pat.ID] = &pat.PersonalAccessToken - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, status.Errorf(status.Internal, "failed to save account: %v", err) + if err = am.Store.SavePAT(ctx, store.LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { + return nil, err } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} @@ -426,48 +376,36 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { - return status.Errorf(status.NotFound, "account not found: %s", err) + return err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return status.Errorf(status.NotFound, "user not found") + if initiatorUser.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return status.Errorf(status.NotFound, "user not found") + if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { + return status.NewAdminPermissionError() } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { - return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") - } - - pat := targetUser.PATs[tokenID] - if pat == nil { - return status.Errorf(status.NotFound, "PAT not found") - } - - err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) + pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) if err != nil { - return status.Errorf(status.Internal, "Failed to delete token id index: %s", err) + return err } - err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken) + + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { - return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) + return err + } + + if err = am.Store.DeletePAT(ctx, store.LockingStrengthUpdate, targetUserID, tokenID); err != nil { + return err } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) - delete(targetUser.PATs, tokenID) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return status.Errorf(status.Internal, "Failed to save account: %s", err) - } return nil } @@ -478,22 +416,15 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) - if err != nil { - return nil, err + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - for _, pat := range targetUser.PATsG { - if pat.ID == tokenID { - return pat.Copy(), nil - } - } - - return nil, status.Errorf(status.NotFound, "PAT not found") + return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) } // GetAllPATs returns all PATs for a user @@ -503,21 +434,15 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) - if err != nil { - return nil, err + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + if initiatorUserID != targetUserID && initiatorUser.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - pats := make([]*types.PersonalAccessToken, 0, len(targetUser.PATsG)) - for _, pat := range targetUser.PATsG { - pats = append(pats, pat.Copy()) - } - - return pats, nil + return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID) } // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. @@ -528,10 +453,6 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { - if update == nil { - return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") - } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -555,125 +476,113 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, nil //nolint:nilnil } - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - initiatorUser, err := account.FindUser(initiatorUserID) - if err != nil { - return nil, err + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") + return nil, status.NewAdminPermissionError() } - updatedUsers := make([]*types.UserInfo, 0, len(updates)) - var ( - expiredPeers []*nbpeer.Peer - userIDs []string - eventsToStore []func() - ) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } - for _, update := range updates { - if update == nil { - return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") - } + var updateAccountPeers bool + var peersToExpire []*nbpeer.Peer + var addUserEvents []func() + var usersToSave = make([]*types.User, 0, len(updates)) - userIDs = append(userIDs, update.Id) + groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("error getting account groups: %w", err) + } - oldUser := account.Users[update.Id] - if oldUser == nil { - if !addIfNotExists { - return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + groupsMap := make(map[string]*types.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + for _, update := range updates { + if update == nil { + return status.Errorf(status.InvalidArgument, "provided user update is nil") } - // when addIfNotExists is set to true, the newUser will use all fields from the update input - oldUser = update - } - if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil { - return nil, err - } - - // only auto groups, revoked status, and integration reference can be updated for now - newUser := oldUser.Copy() - newUser.Role = update.Role - newUser.Blocked = update.Blocked - newUser.AutoGroups = update.AutoGroups - // these two fields can't be set via API, only via direct call to the method - newUser.Issued = update.Issued - newUser.IntegrationReference = update.IntegrationReference - - transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update) - account.Users[newUser.Id] = newUser - - if !oldUser.IsBlocked() && update.IsBlocked() { - // expire peers that belong to the user who's getting blocked - blockedPeers, err := account.FindUserPeers(update.Id) + userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( + ctx, transaction, groupsMap, initiatorUser, update, addIfNotExists, settings, + ) if err != nil { - return nil, err + return fmt.Errorf("failed to process user update: %w", err) + } + usersToSave = append(usersToSave, updatedUser) + addUserEvents = append(addUserEvents, userEvents...) + peersToExpire = append(peersToExpire, userPeersToExpire...) + + if userHadPeers { + updateAccountPeers = true } - expiredPeers = append(expiredPeers, blockedPeers...) } - - peerGroupsAdded := make(map[string][]string) - peerGroupsRemoved := make(map[string][]string) - if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { - removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) - // need force update all auto groups in any case they will not be duplicated - peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) - peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) - } - - userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) - eventsToStore = append(eventsToStore, userUpdateEvents...) - - userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved) - eventsToStore = append(eventsToStore, userGroupsEvents...) - - updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) - if err != nil { - return nil, err - } - updatedUsers = append(updatedUsers, updatedUserInfo) + return transaction.SaveUsers(ctx, store.LockingStrengthUpdate, usersToSave) + }) + if err != nil { + return nil, err } - if len(expiredPeers) > 0 { - if err := am.expireAndUpdatePeers(ctx, account.Id, expiredPeers); err != nil { + var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates)) + + userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID) + if err != nil { + return nil, err + } + + for _, updatedUser := range usersToSave { + updatedUserInfo, ok := userInfos[updatedUser.Id] + if !ok || updatedUserInfo == nil { + return nil, fmt.Errorf("failed to get user: %s updated user info", updatedUser.Id) + } + updatedUsersInfo = append(updatedUsersInfo, updatedUserInfo) + } + + for _, addUserEvent := range addUserEvents { + addUserEvent() + } + + if len(peersToExpire) > 0 { + if err := am.expireAndUpdatePeers(ctx, accountID, peersToExpire); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err + if settings.GroupsPropagationEnabled && updateAccountPeers { + if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return nil, fmt.Errorf("failed to increment network serial: %w", err) + } + am.UpdateAccountPeers(ctx, accountID) } - if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.UpdateAccountPeers(ctx, account.Id) - } - - for _, storeEvent := range eventsToStore { - storeEvent() - } - - return updatedUsers, nil + return updatedUsersInfo, nil } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { if newUser.IsBlocked() { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) }) } else { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) }) } } @@ -681,115 +590,126 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in switch { case transferredOwnerRole: eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) }) case oldUser.Role != newUser.Role: eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) }) } return eventsToStore } -func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { - var eventsToStore []func() - if newUser.AutoGroups != nil { - removedGroups := util.Difference(oldUser.AutoGroups, newUser.AutoGroups) - addedGroups := util.Difference(newUser.AutoGroups, oldUser.AutoGroups) +func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transaction store.Store, groupsMap map[string]*types.Group, + initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { - removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved) - eventsToStore = append(eventsToStore, removedEvents...) - - addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded) - eventsToStore = append(eventsToStore, addedEvents...) + if update == nil { + return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - return eventsToStore + + oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, update, addIfNotExists) + if err != nil { + return false, nil, nil, nil, err + } + + if err := validateUserUpdate(groupsMap, initiatorUser, oldUser, update); err != nil { + return false, nil, nil, nil, err + } + + // only auto groups, revoked status, and integration reference can be updated for now + updatedUser := oldUser.Copy() + updatedUser.AccountID = initiatorUser.AccountID + updatedUser.Role = update.Role + updatedUser.Blocked = update.Blocked + updatedUser.AutoGroups = update.AutoGroups + // these two fields can't be set via API, only via direct call to the method + updatedUser.Issued = update.Issued + updatedUser.IntegrationReference = update.IntegrationReference + + transferredOwnerRole, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update) + if err != nil { + return false, nil, nil, nil, err + } + + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id) + if err != nil { + return false, nil, nil, nil, err + } + + var peersToExpire []*nbpeer.Peer + + if !oldUser.IsBlocked() && update.IsBlocked() { + peersToExpire = userPeers + } + + if update.AutoGroups != nil && settings.GroupsPropagationEnabled { + removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) + updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) + if err != nil { + return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) + } + + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { + return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err) + } + } + + updateAccountPeers := len(userPeers) > 0 + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole) + + return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } -func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { - var eventsToStore []func() - for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) +// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. +func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) { + existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + if !addIfNotExists { + return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + } + return update, nil // use all fields from update if addIfNotExists is true } + return nil, err } - for groupID, peerIDs := range peerGroupsAdded { - group := account.GetGroup(groupID) - for _, peerID := range peerIDs { - peer := account.GetPeer(peerID) - eventsToStore = append(eventsToStore, func() { - meta := map[string]any{ - "group": group.Name, "group_id": group.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } - am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta) - }) - } - } - return eventsToStore + return existingUser, nil } -func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { - var eventsToStore []func() - for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) - - } else { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) - } - } - for groupID, peerIDs := range peerGroupsRemoved { - group := account.GetGroup(groupID) - for _, peerID := range peerIDs { - peer := account.GetPeer(peerID) - eventsToStore = append(eventsToStore, func() { - meta := map[string]any{ - "group": group.Name, "group_id": group.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), - } - am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta) - }) - } - } - return eventsToStore -} - -func handleOwnerRoleTransfer(account *types.Account, initiatorUser, update *types.User) bool { +func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) { if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner { newInitiatorUser := initiatorUser.Copy() newInitiatorUser.Role = types.UserRoleAdmin - account.Users[initiatorUser.Id] = newInitiatorUser - return true + + if err := transaction.SaveUser(ctx, store.LockingStrengthUpdate, newInitiatorUser); err != nil { + return false, err + } + return true, nil } - return false + return false, nil } // getUserInfo retrieves the UserInfo for a given User and Account. // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. -func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *types.User, account *types.Account) (*types.UserInfo, error) { +func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) { + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + if !isNil(am.idpManager) && !user.IsServiceUser { - userData, err := am.lookupUserInCache(ctx, user.Id, account) + userData, err := am.lookupUserInCache(ctx, user.Id, accountID) if err != nil { return nil, err } - return user.ToUserInfo(userData, account.Settings) + return user.ToUserInfo(userData, settings) } - return user.ToUserInfo(nil, account.Settings) + return user.ToUserInfo(nil, settings) } // validateUserUpdate validates the update operation for a user. -func validateUserUpdate(account *types.Account, initiatorUser, oldUser, update *types.User) error { +func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUser, update *types.User) error { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } @@ -810,12 +730,12 @@ func validateUserUpdate(account *types.Account, initiatorUser, oldUser, update * } for _, newGroupID := range update.AutoGroups { - group, ok := account.Groups[newGroupID] + group, ok := groupsMap[newGroupID] if !ok { return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", newGroupID, update.Id) } - if group.Name == "All" { + if group.IsGroupAll() { return status.Errorf(status.InvalidArgument, "can't add All group to the user") } } @@ -864,22 +784,38 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. -func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) { - account, err := am.Store.GetAccount(ctx, accountID) +func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, initiatorUserID string) (map[string]*types.UserInfo, error) { + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + if err != nil { + return nil, err + } + + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + return am.BuildUserInfosForAccount(ctx, accountID, initiatorUserID, accountUsers) +} + +// BuildUserInfosForAccount builds user info for the given account. +func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) { + var queriedUsers []*idp.UserData + var err error + + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - queriedUsers := make([]*idp.UserData, 0) if !isNil(am.idpManager) { - users := make(map[string]userLoggedInOnce, len(account.Users)) + users := make(map[string]userLoggedInOnce, len(accountUsers)) usersFromIntegration := make([]*idp.UserData, 0) - for _, user := range account.Users { + for _, user := range accountUsers { if user.Issued == types.UserIssuedIntegration { key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) @@ -904,33 +840,40 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun queriedUsers = append(queriedUsers, usersFromIntegration...) } - userInfos := make([]*types.UserInfo, 0) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + userInfosMap := make(map[string]*types.UserInfo) // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { - for _, accountUser := range account.Users { - if !(user.HasAdminPower() || user.IsServiceUser || user.Id == accountUser.Id) { + for _, accountUser := range accountUsers { + if initiatorUser.IsRegularUser() && initiatorUser.Id != accountUser.Id { // if user is not an admin then show only current user and do not show other users continue } - info, err := accountUser.ToUserInfo(nil, account.Settings) + + info, err := accountUser.ToUserInfo(nil, settings) if err != nil { return nil, err } - userInfos = append(userInfos, info) + userInfosMap[accountUser.Id] = info } - return userInfos, nil + + return userInfosMap, nil } - for _, localUser := range account.Users { - if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != localUser.Id { + for _, localUser := range accountUsers { + if initiatorUser.IsRegularUser() && initiatorUser.Id != localUser.Id { // if user is not an admin then show only current user and do not show other users continue } var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.ToUserInfo(queriedUser, account.Settings) + info, err = localUser.ToUserInfo(queriedUser, settings) if err != nil { return nil, err } @@ -943,7 +886,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun dashboardViewPermissions := "full" if !localUser.HasAdminPower() { dashboardViewPermissions = "limited" - if account.Settings.RegularUsersViewBlocked { + if settings.RegularUsersViewBlocked { dashboardViewPermissions = "blocked" } } @@ -960,10 +903,10 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions}, } } - userInfos = append(userInfos, info) + userInfosMap[info.ID] = info } - return userInfos, nil + return userInfosMap, nil } // expireAndUpdatePeers expires all peers of the given user and updates them in the account @@ -1017,55 +960,34 @@ func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUs return nil } -func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context, accountId, initiatorId, targetId string) (string, string, error) { - userInfos, err := am.GetUsersFromAccount(ctx, accountId, initiatorId) - if err != nil { - return "", "", err - } - for _, ui := range userInfos { - if ui.ID == targetId { - return ui.Email, ui.Name, nil - } - } - - return "", "", fmt.Errorf("user info not found for user: %s", targetId) -} - // DeleteRegularUsers deletes regular users from an account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. // // If an error occurs while deleting the user, the function skips it and continues deleting other users. // Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountID) +func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return err } - executingUser := account.Users[initiatorUserID] - if executingUser == nil { - return status.Errorf(status.NotFound, "user not found") - } - if !executingUser.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power can delete users") + if !initiatorUser.HasAdminPower() { + return status.NewAdminPermissionError() } - var ( - allErrors error - updateAccountPeers bool - ) + var allErrors error + var updateAccountPeers bool - deletedUsersMeta := make(map[string]map[string]any) for _, targetUserID := range targetUserIDs { if initiatorUserID == targetUserID { allErrors = errors.Join(allErrors, errors.New("self deletion is not allowed")) continue } - targetUser := account.Users[targetUserID] - if targetUser == nil { - allErrors = errors.Join(allErrors, fmt.Errorf("target user: %s not found", targetUserID)) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + if err != nil { + allErrors = errors.Join(allErrors, err) continue } @@ -1075,88 +997,97 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !initiatorUser.IsServiceUser { allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user")) continue } - meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) - if err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err)) + userInfo, ok := userInfos[targetUserID] + if !ok || userInfo == nil { + allErrors = errors.Join(allErrors, fmt.Errorf("user info not found for user: %s", targetUserID)) continue } - if hadPeers { - updateAccountPeers = true + userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + if err != nil { + allErrors = errors.Join(allErrors, err) + continue } - delete(account.Users, targetUserID) - deletedUsersMeta[targetUserID] = meta - } - - if updateAccountPeers { - account.Network.IncSerial() - } - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return fmt.Errorf("failed to delete users: %w", err) + if userHadPeers { + updateAccountPeers = true + } } if updateAccountPeers { am.UpdateAccountPeers(ctx, accountID) } - for targetUserID, meta := range deletedUsersMeta { - am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - } - return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { - tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) - if err != nil { - log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) - return nil, false, err - } - +// deleteRegularUser deletes a specified user and their related peers from the account. +func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountID, initiatorUserID string, targetUserInfo *types.UserInfo) (bool, error) { if !isNil(am.idpManager) { // Delete if the user already exists in the IdP. Necessary in cases where a user account // was created where a user account was provisioned but the user did not sign in - _, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id}) + _, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID, idp.AppMetadata{WTAccountID: accountID}) if err == nil { - err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) + err = am.deleteUserFromIDP(ctx, targetUserInfo.ID, accountID) if err != nil { - log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) - return nil, false, err + log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserInfo.ID) + return false, err } } else { - log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) + log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserInfo.ID, err) } } - hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) + var addPeerRemovedEvents []func() + var updateAccountPeers bool + var targetUser *types.User + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID) + if err != nil { + return fmt.Errorf("failed to get user to delete: %w", err) + } + + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID) + if err != nil { + return fmt.Errorf("failed to get user peers: %w", err) + } + + if len(userPeers) > 0 { + updateAccountPeers = true + addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers) + if err != nil { + return fmt.Errorf("failed to delete user peers: %w", err) + } + } + + if err = transaction.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUserInfo.ID); err != nil { + return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err) + } + + return nil + }) if err != nil { - return nil, false, err + return false, err } - u, err := account.FindUser(targetUserID) - if err != nil { - log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err) + for _, addPeerRemovedEvent := range addPeerRemovedEvents { + addPeerRemovedEvent() } + meta := map[string]any{"name": targetUserInfo.Name, "email": targetUserInfo.Email, "created_at": targetUser.CreatedAt} + am.StoreEvent(ctx, initiatorUserID, targetUser.Id, accountID, activity.UserDeleted, meta) - var tuCreatedAt time.Time - if u != nil { - tuCreatedAt = u.CreatedAt - } - - return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil + return updateAccountPeers, nil } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, - groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { - +func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return } @@ -1230,12 +1161,22 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa return nil, false } -// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. -func areUsersLinkedToPeers(account *types.Account, userIDs []string) bool { - for _, peer := range account.Peers { - if slices.Contains(userIDs, peer.UserID) { - return true - } +func validateUserInvite(invite *types.UserInfo) error { + if invite == nil { + return fmt.Errorf("provided user update is nil") } - return false + + invitedRole := types.StrRoleToUserRole(invite.Role) + + switch { + case invite.Name == "": + return status.Errorf(status.InvalidArgument, "name can't be empty") + case invite.Email == "": + return status.Errorf(status.InvalidArgument, "email can't be empty") + case invitedRole == types.UserRoleOwner: + return status.Errorf(status.InvalidArgument, "can't invite a user with owner role") + default: + } + + return nil } diff --git a/management/server/user_test.go b/management/server/user_test.go index a028d164b..4a532c8a6 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -11,6 +11,7 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" "github.com/netbirdio/netbird/management/server/util" + "golang.org/x/exp/maps" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -45,7 +46,7 @@ const ( ) func TestUser_CreatePAT_ForSameUser(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) if err != nil { t.Fatalf("Error when creating store: %s", err) } @@ -53,13 +54,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err = store.SaveAccount(context.Background(), account) + err = s.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } am := DefaultAccountManager{ - Store: store, + Store: s, eventStore: &activity.InMemoryEventStore{}, } @@ -81,7 +82,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { assert.Equal(t, pat.ID, tokenID) - user, err := am.Store.GetUserByTokenID(context.Background(), tokenID) + user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID) if err != nil { t.Fatalf("Error when getting user by token ID: %s", err) } @@ -855,7 +856,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { { name: "Delete non-existent user", userIDs: []string{"non-existent-user"}, - expectedReasons: []string{"target user: non-existent-user not found"}, + expectedReasons: []string{"user: non-existent-user not found"}, expectedNotDeleted: []string{}, }, { @@ -867,7 +868,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs) + userInfos, err := am.BuildUserInfosForAccount(context.Background(), mockAccountID, mockUserID, maps.Values(account.Users)) + assert.NoError(t, err) + + err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs, userInfos) if len(tc.expectedReasons) > 0 { assert.Error(t, err) var foundExpectedErrors int