mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-18 15:56:41 +02:00
[management] Refactor users to use store methods (#2917)
* Refactor setup key handling to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add lock to get account groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add check for regular user Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * get only required groups for auto-group validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add account lock and return auto groups map on validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor account peers update Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor groups to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Run groups ops in transaction Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix missing group removed from setup key activity Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor posture checks to remove get and save account Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix sonar Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Change setup key log level to debug for missing group Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Retrieve modified peers once for group events Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor policy get and save account to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Retrieve policy groups and posture checks once for validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix typo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add policy tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor anyGroupHasPeers to retrieve all groups once Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor dns settings to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add account locking and merge group deletion methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor name server groups to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add peer store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor ephemeral peers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add lock for peer store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor peer handlers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor peer to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix typo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add locks and remove log Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * run peer ops in transaction Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove duplicate store method Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix peer fields updated after save Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use update strength and simplify check Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * prevent changing ruleID when not empty Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * prevent duplicate rules during updates Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix lint Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor auth middleware Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor account methods and mock Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor user and PAT handling Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Remove db query context and fix get user by id Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix database transaction locking issue Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use UTC time in test Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add account locks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix prevent users from creating PATs for other users Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add missing tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor test names and remove duplicate TestPostgresql_SavePeerStatus Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add account locks and remove redundant ephemeral check Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Retrieve all groups for peers and restrict groups for regular users Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix store tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * use account object to get validated peers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Improve peer performance Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Get account direct from store without buffer Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add get peer groups tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Adjust benchmarks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Adjust benchmarks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Update benchmark workflow (#3181) * update local benchmark expectations * update cloud expectations * Add status error for generic result error Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use integrated validator direct Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations * update expectations * update expectations * Refactor peer scheduler to retry every 3 seconds on errors Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations * fix validator * fix validator * fix validator * update timeouts * Refactor ToGroupsInfo to process slices of groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations * update expectations * update expectations * Bump integrations version Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor GetValidatedPeers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * go mod tidy Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use peers and groups map for peers validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove mysql from api benchmark tests * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix blocked db calls on user auto groups update Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Skip user check for system initiated peer deletion Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Remove context in db calls Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Improve group peer/resource counting (#3192) * Fix sonar Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Adjust bench expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Rename GetAccountInfoFromPAT to GetTokenInfo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Remove global account lock for ListUsers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * build userinfo after updating users in db Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Optimize user bulk deletion (#3315) * refactor building user infos Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove unused code Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor GetUsersFromAccount to return a map of UserInfo instead of a slice Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Export BuildUserInfosForAccount to account manager Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fetch account user info once for bulk users save Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update user deletion expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Set max open conns for activity store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update bench expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: Pascal Fischer <pascal@netbird.io> Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
This commit is contained in:
parent
abe8da697c
commit
4cdb2e533a
@ -260,6 +260,7 @@ func TestDNS_Integration(t *testing.T) {
|
||||
nsGroupReq := api.NameserverGroupRequest{
|
||||
Description: "Test",
|
||||
Enabled: true,
|
||||
Domains: []string{},
|
||||
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
|
||||
Name: "test",
|
||||
Nameservers: []api.Nameserver{
|
||||
|
@ -67,7 +67,7 @@ type AccountManager interface {
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
|
||||
@ -79,7 +79,7 @@ type AccountManager interface {
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
@ -96,7 +96,7 @@ type AccountManager interface {
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
@ -149,6 +149,7 @@ type AccountManager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@ -617,6 +618,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
if user.Role != types.UserRoleOwner {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||
}
|
||||
|
||||
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
for _, otherUser := range account.Users {
|
||||
if otherUser.IsServiceUser {
|
||||
continue
|
||||
@ -626,13 +633,23 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
continue
|
||||
}
|
||||
|
||||
deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id)
|
||||
userInfo, ok := userInfosMap[otherUser.Id]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id)
|
||||
}
|
||||
|
||||
_, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||
if deleteUserErr != nil {
|
||||
return deleteUserErr
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteRegularUser(ctx, account, userID, userID)
|
||||
userInfo, ok := userInfosMap[userID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "user info not found for user %s", userID)
|
||||
}
|
||||
|
||||
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
|
||||
return err
|
||||
@ -689,20 +706,8 @@ func isNil(i idp.Manager) bool {
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cachedAccount := &types.Account{
|
||||
Id: accountID,
|
||||
Users: make(map[string]*types.User),
|
||||
}
|
||||
for _, user := range accountUsers {
|
||||
cachedAccount.Users[user.Id] = user
|
||||
}
|
||||
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -778,10 +783,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
|
||||
}
|
||||
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) {
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||
// ignore service users and users provisioned by integrations than are never logged in
|
||||
for _, user := range account.Users {
|
||||
for _, user := range accountUsers {
|
||||
if user.IsServiceUser {
|
||||
continue
|
||||
}
|
||||
@ -790,8 +800,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
}
|
||||
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
||||
}
|
||||
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||
userData, err := am.lookupCache(ctx, users, account.Id)
|
||||
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
|
||||
userData, err := am.lookupCache(ctx, users, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -804,13 +814,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
|
||||
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
||||
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
||||
user, err := account.FindUser(userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := user.IntegrationReference.CacheKey(account.Id, userID)
|
||||
key := user.IntegrationReference.CacheKey(accountID, userID)
|
||||
ud, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
|
||||
@ -1050,9 +1060,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
usersMap := make(map[string]*types.User)
|
||||
usersMap[claims.UserId] = types.NewRegularUser(claims.UserId)
|
||||
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
||||
newUser := types.NewRegularUser(claims.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -1075,12 +1085,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1090,17 +1095,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
}
|
||||
|
||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
||||
// Our job is to just reload cache.
|
||||
go func() {
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
|
||||
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
|
||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@ -1109,33 +1114,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pat, ok := account.Users[user.Id].PATs[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("token not found")
|
||||
}
|
||||
|
||||
pat.LastUsed = util.ToPtr(time.Now().UTC())
|
||||
|
||||
return am.Store.SaveAccount(ctx, account)
|
||||
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
||||
}
|
||||
|
||||
// GetAccount returns an account associated with this account ID.
|
||||
@ -1143,52 +1122,64 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
return user, pat, domain, category, nil
|
||||
}
|
||||
|
||||
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
|
||||
if len(token) != types.PATLength {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong length")
|
||||
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||
}
|
||||
|
||||
prefix := token[:len(types.PATPrefix)]
|
||||
if prefix != types.PATPrefix {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
return nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
}
|
||||
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
||||
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum does not match")
|
||||
return nil, nil, fmt.Errorf("token checksum does not match")
|
||||
}
|
||||
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
|
||||
|
||||
var user *types.User
|
||||
var pat *types.PersonalAccessToken
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, nil, nil, fmt.Errorf("personal access token not found")
|
||||
}
|
||||
|
||||
return account, user, pat, nil
|
||||
return user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
@ -1334,7 +1325,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error getting user peers: %w", err)
|
||||
}
|
||||
|
||||
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
||||
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||
}
|
||||
|
@ -732,6 +732,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
@ -745,14 +746,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
Store: store,
|
||||
}
|
||||
|
||||
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
|
||||
user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", account.Id)
|
||||
assert.Equal(t, "account_id", user.AccountID)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
@ -95,6 +96,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(runtime.NumCPU())
|
||||
|
||||
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||
if err != nil {
|
||||
|
@ -43,7 +43,7 @@ func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, network
|
||||
)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
accountManager.GetPATInfo,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
|
@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
||||
return make([]*types.UserInfo, 0), nil
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
|
||||
return make(map[string]*types.UserInfo), nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -52,7 +52,7 @@ var usersTestAccount = &types.Account{
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
nonDeletableServiceUserID: {
|
||||
Id: serviceUserID,
|
||||
Id: nonDeletableServiceUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: true,
|
||||
NonDeletable: true,
|
||||
@ -70,10 +70,10 @@ func initUsersTestData() *handler {
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
||||
users := make([]*types.UserInfo, 0)
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
|
||||
usersInfos := make(map[string]*types.UserInfo)
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &types.UserInfo{
|
||||
usersInfos[v.Id] = &types.UserInfo{
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
@ -81,9 +81,9 @@ func initUsersTestData() *handler {
|
||||
IsServiceUser: v.IsServiceUser,
|
||||
NonDeletable: v.NonDeletable,
|
||||
Issued: v.Issued,
|
||||
})
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
return usersInfos, nil
|
||||
},
|
||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
|
@ -19,8 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
@ -47,7 +47,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
claimMaps[jwtclaims.IsToken] = true
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
|
@ -35,6 +35,7 @@ var testAccount = &types.Account{
|
||||
Users: map[string]*types.User{
|
||||
userID: {
|
||||
Id: userID,
|
||||
AccountID: accountID,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
tokenID: {
|
||||
ID: tokenID,
|
||||
@ -50,11 +51,11 @@ var testAccount = &types.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
@ -166,7 +167,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
)
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountFromPAT,
|
||||
mockGetAccountInfoFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
|
@ -35,14 +35,14 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{
|
||||
|
||||
func BenchmarkUpdateUser(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000},
|
||||
"Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50},
|
||||
"Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250},
|
||||
"Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700},
|
||||
"Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400},
|
||||
"Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000},
|
||||
"Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500},
|
||||
"Users - XS": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 310},
|
||||
"Users - S": {MinMsPerOpLocal: 0.3, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
|
||||
"Users - M": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 3, MaxMsPerOpCICD: 20},
|
||||
"Users - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
|
||||
"Peers - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 310},
|
||||
"Groups - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 120},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
|
||||
"Users - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 280},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
@ -118,14 +118,14 @@ func BenchmarkGetOneUser(b *testing.B) {
|
||||
|
||||
func BenchmarkGetAllUsers(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200},
|
||||
"Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90},
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
|
||||
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
|
||||
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
|
||||
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
@ -141,7 +141,7 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
@ -152,14 +152,14 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
||||
|
||||
func BenchmarkDeleteUsers(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000},
|
||||
"Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
|
||||
"Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230},
|
||||
"Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190},
|
||||
"Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800},
|
||||
"Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600},
|
||||
"Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400},
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
|
@ -53,8 +53,8 @@ type MockAccountManager struct {
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@ -69,7 +69,7 @@ type MockAccountManager struct {
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@ -110,6 +110,7 @@ type MockAccountManager struct {
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
@ -165,7 +166,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) {
|
||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) (map[string]*types.UserInfo, error) {
|
||||
if am.GetUsersFromAccountFunc != nil {
|
||||
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
|
||||
}
|
||||
@ -238,12 +239,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(ctx, pat)
|
||||
// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
|
||||
if am.GetPATInfoFunc != nil {
|
||||
return am.GetPATInfoFunc(ctx, pat)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
||||
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
@ -550,9 +551,9 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
|
||||
}
|
||||
|
||||
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
|
||||
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error {
|
||||
if am.DeleteRegularUsersFunc != nil {
|
||||
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
|
||||
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs, userInfos)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
|
||||
}
|
||||
@ -849,3 +850,11 @@ func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peer
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
|
||||
}
|
||||
|
||||
// BuildUserInfosForAccount mocks BuildUserInfosForAccount of the AccountManager interface
|
||||
func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
|
||||
if am.BuildUserInfosForAccountFunc != nil {
|
||||
return am.BuildUserInfosForAccountFunc(ctx, accountID, initiatorUserID, accountUsers)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -28,7 +29,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@ -1554,6 +1554,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
|
@ -93,7 +93,7 @@ func NewPeerNotPartOfAccountError() error {
|
||||
|
||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||
func NewUserNotFoundError(userKey string) error {
|
||||
return Errorf(NotFound, "user not found: %s", userKey)
|
||||
return Errorf(NotFound, "user: %s not found", userKey)
|
||||
}
|
||||
|
||||
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
|
||||
@ -191,3 +191,18 @@ func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {
|
||||
func NewRouterNotPartOfNetworkError(routerID, networkID string) error {
|
||||
return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID)
|
||||
}
|
||||
|
||||
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
|
||||
func NewServiceUserRoleInvalidError() error {
|
||||
return Errorf(InvalidArgument, "can't create a service user with owner role")
|
||||
}
|
||||
|
||||
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
|
||||
// to delete a user with the owner role.
|
||||
func NewOwnerDeletePermissionError() error {
|
||||
return Errorf(PermissionDenied, "can't delete a user with the owner role")
|
||||
}
|
||||
|
||||
func NewPATNotFoundError(patID string) error {
|
||||
return Errorf(NotFound, "PAT: %s not found", patID)
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
@ -414,24 +415,16 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
|
||||
}
|
||||
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
|
||||
usersToSave := make([]types.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
for id, pat := range user.PATs {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
usersToSave = append(usersToSave, *user)
|
||||
}
|
||||
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
|
||||
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save users to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -439,7 +432,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) err
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save user to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -526,30 +520,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
return token.ID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
|
||||
var token types.PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||
Where("personal_access_tokens.id = ?", patID).First(&user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return nil, status.NewPATNotFoundError(patID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
if token.UserID == "" {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
|
||||
for _, pat := range user.PATsG {
|
||||
user.PATs[pat.ID] = pat.Copy()
|
||||
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
@ -557,8 +538,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@ -569,6 +549,25 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete user from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
||||
var users []*types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||
@ -899,6 +898,20 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
||||
var createdBy string
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
Select("created_by").First(&createdBy, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
return createdBy, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user types.User
|
||||
@ -2053,3 +2066,94 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
|
||||
var pat types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError(hashedToken)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get pat by hash from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
|
||||
var pat types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError(patID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get pat from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// GetUserPATs retrieves personal access tokens for a user.
|
||||
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
|
||||
var pats []*types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user pat's from store")
|
||||
}
|
||||
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used.
|
||||
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
|
||||
patCopy := types.PersonalAccessToken{
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"last_used"}
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, patID).Updates(&patCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to mark pat as used")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPATNotFoundError(patID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePAT saves a personal access token to the database.
|
||||
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save pat to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePAT deletes a personal access token from the database.
|
||||
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete pat from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPATNotFoundError(patID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -627,29 +627,6 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func TestSqlite_GetUserByTokenID(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
|
||||
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
@ -962,23 +939,6 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
require.Equal(t, id, token)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
|
||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
@ -1182,7 +1142,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
func TestSqlStore_GetAccountUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
@ -2915,3 +2875,326 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) {
|
||||
|
||||
t.Logf("Test completed")
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountCreatedBy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectError bool
|
||||
createdBy string
|
||||
}{
|
||||
{
|
||||
name: "existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectError: false,
|
||||
createdBy: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Empty(t, createdBy)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdBy)
|
||||
require.Equal(t, tt.createdBy, createdBy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserByUserID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing user",
|
||||
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing user",
|
||||
userID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty user ID",
|
||||
userID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, tt.userID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, user)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, tt.userID, user.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserByPATID(t *testing.T) {
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUser(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
user := &types.User{
|
||||
Id: "user-id",
|
||||
AccountID: accountID,
|
||||
Role: types.UserRoleAdmin,
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"groupA", "groupB"},
|
||||
Blocked: false,
|
||||
LastLogin: util.ToPtr(time.Now().UTC()),
|
||||
CreatedAt: time.Now().UTC().Add(-time.Hour),
|
||||
Issued: types.UserIssuedIntegration,
|
||||
}
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, user.Id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.Id, saveUser.Id)
|
||||
require.Equal(t, user.AccountID, saveUser.AccountID)
|
||||
require.Equal(t, user.Role, saveUser.Role)
|
||||
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
|
||||
require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
|
||||
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||
require.Equal(t, user.Issued, saveUser.Issued)
|
||||
require.Equal(t, user.Blocked, saveUser.Blocked)
|
||||
require.Equal(t, user.IsServiceUser, saveUser.IsServiceUser)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accountUsers, 2)
|
||||
|
||||
users := []*types.User{
|
||||
{
|
||||
Id: "user-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
AutoGroups: []string{"groupA", "groupB"},
|
||||
},
|
||||
{
|
||||
Id: "user-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
AutoGroups: []string{"groupA"},
|
||||
},
|
||||
}
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accountUsers, 4)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteUser(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, userID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, user)
|
||||
|
||||
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, userPATs, 0)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPATByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
patID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing PAT",
|
||||
patID: "9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing PAT",
|
||||
patID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty PAT ID",
|
||||
patID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, tt.patID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, pat)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pat)
|
||||
require.Equal(t, tt.patID, pat.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserPATs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, "f4f6d672-63fb-11ec-90d6-0242ac120003")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, userPATs, 1)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPATByHashedToken(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "SoMeHaShEdToKeN")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "9dj38s35-63fb-11ec-90d6-0242ac120003", pat.ID)
|
||||
}
|
||||
|
||||
func TestSqlStore_MarkPATUsed(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
err = store.MarkPATUsed(context.Background(), LockingStrengthUpdate, patID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
|
||||
require.NoError(t, err)
|
||||
now := time.Now().UTC()
|
||||
require.WithinRange(t, pat.LastUsed.UTC(), now.Add(-15*time.Second), now, "LastUsed should be within 1 second of now")
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePAT(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
pat := &types.PersonalAccessToken{
|
||||
ID: "pat-id",
|
||||
UserID: userID,
|
||||
Name: "token",
|
||||
HashedToken: "SoMeHaShEdToKeN",
|
||||
ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
|
||||
CreatedBy: userID,
|
||||
CreatedAt: time.Now().UTC().Add(time.Hour),
|
||||
LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
savePAT, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, pat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pat.ID, savePAT.ID)
|
||||
require.Equal(t, pat.UserID, savePAT.UserID)
|
||||
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
|
||||
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
|
||||
require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
|
||||
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||
require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
|
||||
}
|
||||
|
||||
func TestSqlStore_DeletePAT(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
err = store.DeletePAT(context.Background(), LockingStrengthUpdate, userID, patID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, pat)
|
||||
}
|
||||
|
@ -59,21 +59,30 @@ type Store interface {
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
|
||||
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
||||
SaveAccount(ctx context.Context, account *types.Account) error
|
||||
DeleteAccount(ctx context.Context, account *types.Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error)
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
||||
SaveUsers(accountID string, users map[string]*types.User) error
|
||||
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||
|
2
management/server/testdata/store.sql
vendored
2
management/server/testdata/store.sql
vendored
@ -37,7 +37,7 @@ CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`);
|
||||
CREATE INDEX `idx_networks_id` ON `networks`(`id`);
|
||||
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
|
@ -75,7 +75,7 @@ type PersonalAccessTokenGenerated struct {
|
||||
|
||||
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
||||
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
||||
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
hashedToken, plainToken, err := generateNewToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -84,6 +84,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
|
||||
return &PersonalAccessTokenGenerated{
|
||||
PersonalAccessToken: PersonalAccessToken{
|
||||
ID: xid.New().String(),
|
||||
UserID: targetID,
|
||||
Name: name,
|
||||
HashedToken: hashedToken,
|
||||
ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)),
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,7 @@ import (
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@ -45,7 +46,7 @@ const (
|
||||
)
|
||||
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
@ -53,13 +54,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
err = s.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
Store: s,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
@ -81,7 +82,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
assert.Equal(t, pat.ID, tokenID)
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
|
||||
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||
}
|
||||
@ -855,7 +856,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
{
|
||||
name: "Delete non-existent user",
|
||||
userIDs: []string{"non-existent-user"},
|
||||
expectedReasons: []string{"target user: non-existent-user not found"},
|
||||
expectedReasons: []string{"user: non-existent-user not found"},
|
||||
expectedNotDeleted: []string{},
|
||||
},
|
||||
{
|
||||
@ -867,7 +868,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
|
||||
userInfos, err := am.BuildUserInfosForAccount(context.Background(), mockAccountID, mockUserID, maps.Values(account.Users))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs, userInfos)
|
||||
if len(tc.expectedReasons) > 0 {
|
||||
assert.Error(t, err)
|
||||
var foundExpectedErrors int
|
||||
|
Loading…
x
Reference in New Issue
Block a user