[management] Refactor users to use store methods (#2917)

* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor posture checks to remove get and save account

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor policy get and save account to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve policy groups and posture checks once for validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix typo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add policy tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor anyGroupHasPeers to retrieve all groups once

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor dns settings to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor name server groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add peer store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor ephemeral peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add lock for peer store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor peer handlers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor peer to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix typo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add locks and remove log

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* run peer ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove duplicate store method

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix peer fields updated after save

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use update strength and simplify check

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* prevent changing ruleID when not empty

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* prevent duplicate rules during updates

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor auth middleware

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor account methods and mock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor user and PAT handling

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove db query context and fix get user by id

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix database transaction locking issue

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use UTC time in test

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix prevent users from creating PATs for other users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add missing tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor test names and remove duplicate TestPostgresql_SavePeerStatus

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locks and remove redundant ephemeral check

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups for peers and restrict groups for regular users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix store tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* use account object to get validated peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Improve peer performance

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Get account direct from store without buffer

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add get peer groups tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Adjust benchmarks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Adjust benchmarks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Update benchmark workflow (#3181)

* update local benchmark expectations

* update cloud expectations

* Add status error for generic result error

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use integrated validator direct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

* update expectations

* update expectations

* Refactor peer scheduler to retry every 3 seconds on errors

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

* fix validator

* fix validator

* fix validator

* update timeouts

* Refactor ToGroupsInfo to process slices of groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

* update expectations

* update expectations

* Bump integrations version

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor GetValidatedPeers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* go mod tidy

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Use peers and groups map for peers validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove mysql from api benchmark tests

* Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix blocked db calls on user auto groups update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Skip user check for system initiated peer deletion

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context in db calls

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Improve group peer/resource counting (#3192)

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Adjust bench expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Rename GetAccountInfoFromPAT to GetTokenInfo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove global account lock for ListUsers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* build userinfo after updating users in db

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Optimize user bulk deletion  (#3315)

* refactor building user infos

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove unused code

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor GetUsersFromAccount to return a map of UserInfo instead of a slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Export BuildUserInfosForAccount to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fetch account user info once for bulk users save

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update user deletion expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Set max open conns for activity store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update bench expectations

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Co-authored-by: Pascal Fischer <pascal@netbird.io>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
This commit is contained in:
Bethuel Mmbaga 2025-02-17 21:43:12 +03:00 committed by GitHub
parent abe8da697c
commit 4cdb2e533a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1080 additions and 719 deletions

View File

@ -260,6 +260,7 @@ func TestDNS_Integration(t *testing.T) {
nsGroupReq := api.NameserverGroupRequest{
Description: "Test",
Enabled: true,
Domains: []string{},
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
Name: "test",
Nameservers: []api.Nameserver{

View File

@ -67,7 +67,7 @@ type AccountManager interface {
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
@ -79,7 +79,7 @@ type AccountManager interface {
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*types.User, error)
@ -96,7 +96,7 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
@ -149,6 +149,7 @@ type AccountManager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
}
type DefaultAccountManager struct {
@ -617,6 +618,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
if user.Role != types.UserRoleOwner {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
if err != nil {
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
for _, otherUser := range account.Users {
if otherUser.IsServiceUser {
continue
@ -626,13 +633,23 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
continue
}
deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id)
userInfo, ok := userInfosMap[otherUser.Id]
if !ok {
return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id)
}
_, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo)
if deleteUserErr != nil {
return deleteUserErr
}
}
err = am.deleteRegularUser(ctx, account, userID, userID)
userInfo, ok := userInfosMap[userID]
if !ok {
return status.Errorf(status.NotFound, "user info not found for user %s", userID)
}
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
if err != nil {
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
return err
@ -689,20 +706,8 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
cachedAccount := &types.Account{
Id: accountID,
Users: make(map[string]*types.User),
}
for _, user := range accountUsers {
cachedAccount.Users[user.Id] = user
}
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
@ -778,10 +783,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
}
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) {
users := make(map[string]userLoggedInOnce, len(account.Users))
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
users := make(map[string]userLoggedInOnce, len(accountUsers))
// ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users {
for _, user := range accountUsers {
if user.IsServiceUser {
continue
}
@ -790,8 +800,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
}
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
}
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(ctx, users, account.Id)
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
userData, err := am.lookupCache(ctx, users, accountID)
if err != nil {
return nil, err
}
@ -804,13 +814,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := account.FindUser(userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err
}
key := user.IntegrationReference.CacheKey(account.Id, userID)
key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
@ -1050,9 +1060,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
usersMap := make(map[string]*types.User)
usersMap[claims.UserId] = types.NewRegularUser(claims.UserId)
err := am.Store.SaveUsers(domainAccountID, usersMap)
newUser := types.NewRegularUser(claims.UserId)
newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
if err != nil {
return "", err
}
@ -1075,12 +1085,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
return nil
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
@ -1090,17 +1095,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
}
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
_, err = am.refreshCache(ctx, account.Id)
_, err = am.refreshCache(ctx, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
return
}
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
}()
}
@ -1109,33 +1114,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = util.ToPtr(time.Now().UTC())
return am.Store.SaveAccount(ctx, account)
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
}
// GetAccount returns an account associated with this account ID.
@ -1143,52 +1122,64 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
user, pat, err = am.extractPATFromToken(ctx, token)
if err != nil {
return nil, nil, "", "", err
}
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
return user, pat, domain, category, nil
}
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
if len(token) != types.PATLength {
return nil, nil, nil, fmt.Errorf("token has wrong length")
return nil, nil, fmt.Errorf("token has incorrect length")
}
prefix := token[:len(types.PATPrefix)]
if prefix != types.PATPrefix {
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
return nil, nil, fmt.Errorf("token has wrong prefix")
}
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, nil, fmt.Errorf("token checksum does not match")
return nil, nil, fmt.Errorf("token checksum does not match")
}
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
var user *types.User
var pat *types.PersonalAccessToken
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
if err != nil {
return err
}
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
return err
})
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return nil, nil, nil, err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return nil, nil, nil, err
}
pat := user.PATs[tokenID]
if pat == nil {
return nil, nil, nil, fmt.Errorf("personal access token not found")
}
return account, user, pat, nil
return user, pat, nil
}
// GetAccountByID returns an account associated with this account ID.
@ -1334,7 +1325,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}

View File

@ -732,6 +732,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
UserID: "someUser",
HashedToken: encodedHashedToken,
},
},
@ -745,14 +746,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store,
}
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
assert.Equal(t, "account_id", account.Id)
assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
}
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"path/filepath"
"runtime"
"time"
_ "github.com/mattn/go-sqlite3"
@ -95,6 +96,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
if err != nil {
return nil, err
}
db.SetMaxOpenConns(runtime.NumCPU())
crypt, err := NewFieldEncrypt(encryptionKey)
if err != nil {

View File

@ -43,7 +43,7 @@ func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, network
)
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
accountManager.GetPATInfo,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups,

View File

@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
return make([]*types.UserInfo, 0), nil
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
return make(map[string]*types.UserInfo), nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -52,7 +52,7 @@ var usersTestAccount = &types.Account{
Issued: types.UserIssuedAPI,
},
nonDeletableServiceUserID: {
Id: serviceUserID,
Id: nonDeletableServiceUserID,
Role: "admin",
IsServiceUser: true,
NonDeletable: true,
@ -70,10 +70,10 @@ func initUsersTestData() *handler {
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
users := make([]*types.UserInfo, 0)
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
usersInfos := make(map[string]*types.UserInfo)
for _, v := range usersTestAccount.Users {
users = append(users, &types.UserInfo{
usersInfos[v.Id] = &types.UserInfo{
ID: v.Id,
Role: string(v.Role),
Name: "",
@ -81,9 +81,9 @@ func initUsersTestData() *handler {
IsServiceUser: v.IsServiceUser,
NonDeletable: v.NonDeletable,
Issued: v.Issued,
})
}
}
return users, nil
return usersInfos, nil
},
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
if userID != existingUserID {

View File

@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@ -47,7 +47,7 @@ const (
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint

View File

@ -34,7 +34,8 @@ var testAccount = &types.Account{
Domain: domain,
Users: map[string]*types.User{
userID: {
Id: userID,
Id: userID,
AccountID: accountID,
PATs: map[string]*types.PersonalAccessToken{
tokenID: {
ID: tokenID,
@ -50,11 +51,11 @@ var testAccount = &types.Account{
},
}
func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
@ -166,7 +167,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
)
authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockGetAccountInfoFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockCheckUserAccessByJWTGroups,

View File

@ -35,14 +35,14 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{
func BenchmarkUpdateUser(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000},
"Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50},
"Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250},
"Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700},
"Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400},
"Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000},
"Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000},
"Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500},
"Users - XS": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 310},
"Users - S": {MinMsPerOpLocal: 0.3, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
"Users - M": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 3, MaxMsPerOpCICD: 20},
"Users - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
"Peers - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 310},
"Groups - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 120},
"Setup Keys - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
"Users - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 280},
}
log.SetOutput(io.Discard)
@ -118,14 +118,14 @@ func BenchmarkGetOneUser(b *testing.B) {
func BenchmarkGetAllUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200},
"Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90},
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
}
log.SetOutput(io.Discard)
@ -141,7 +141,7 @@ func BenchmarkGetAllUsers(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
@ -152,14 +152,14 @@ func BenchmarkGetAllUsers(b *testing.B) {
func BenchmarkDeleteUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000},
"Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
"Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230},
"Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190},
"Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800},
"Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500},
"Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600},
"Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400},
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
}
log.SetOutput(io.Discard)

View File

@ -53,8 +53,8 @@ type MockAccountManager struct {
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@ -69,7 +69,7 @@ type MockAccountManager struct {
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
@ -110,6 +110,7 @@ type MockAccountManager struct {
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
@ -165,7 +166,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) {
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) (map[string]*types.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
}
@ -238,12 +239,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
if am.GetAccountFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat)
// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
if am.GetPATInfoFunc != nil {
return am.GetPATInfoFunc(ctx, pat)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
@ -550,9 +551,9 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
}
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error {
if am.DeleteRegularUsersFunc != nil {
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs, userInfos)
}
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
}
@ -849,3 +850,11 @@ func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peer
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}
// BuildUserInfosForAccount mocks BuildUserInfosForAccount of the AccountManager interface
func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
if am.BuildUserInfosForAccountFunc != nil {
return am.BuildUserInfosForAccountFunc(ctx, accountID, initiatorUserID, accountUsers)
}
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
}

View File

@ -13,6 +13,7 @@ import (
"testing"
"time"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@ -28,7 +29,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
@ -1554,7 +1554,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
AccountID: account.Id,
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,

View File

@ -93,7 +93,7 @@ func NewPeerNotPartOfAccountError() error {
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
func NewUserNotFoundError(userKey string) error {
return Errorf(NotFound, "user not found: %s", userKey)
return Errorf(NotFound, "user: %s not found", userKey)
}
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
@ -191,3 +191,18 @@ func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {
func NewRouterNotPartOfNetworkError(routerID, networkID string) error {
return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID)
}
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
func NewServiceUserRoleInvalidError() error {
return Errorf(InvalidArgument, "can't create a service user with owner role")
}
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
// to delete a user with the owner role.
func NewOwnerDeletePermissionError() error {
return Errorf(PermissionDenied, "can't delete a user with the owner role")
}
func NewPATNotFoundError(patID string) error {
return Errorf(NotFound, "PAT: %s not found", patID)
}

View File

@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/util"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
@ -414,24 +415,16 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
}
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
usersToSave := make([]types.User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
usersToSave = append(usersToSave, *user)
}
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
if err != nil {
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error {
if len(users) == 0 {
return nil
}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save users to store")
}
return nil
}
@ -439,7 +432,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) err
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user to store")
}
return nil
}
@ -526,30 +520,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
return token.ID, nil
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
var token types.PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
var user types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).First(&user)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
return nil, status.NewPATNotFoundError(patID)
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if token.UserID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user types.User
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
@ -557,8 +538,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
var user types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@ -569,6 +549,25 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error {
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
if result.Error != nil {
return result.Error
}
return tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
})
if err != nil {
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete user from store")
}
return nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
var users []*types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
@ -899,6 +898,20 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
return accountSettings.Settings, nil
}
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
var createdBy string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewAccountNotFoundError(accountID)
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
return createdBy, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user types.User
@ -2053,3 +2066,94 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
return nil
}
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError(hashedToken)
}
log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get pat by hash from store")
}
return &pat, nil
}
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError(patID)
}
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get pat from store")
}
return &pat, nil
}
// GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
var pats []*types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get user pat's from store")
}
return pats, nil
}
// MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
patCopy := types.PersonalAccessToken{
LastUsed: util.ToPtr(time.Now().UTC()),
}
fieldsToUpdate := []string{"last_used"}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate).
Where(idQueryCondition, patID).Updates(&patCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark pat as used")
}
if result.RowsAffected == 0 {
return status.NewPATNotFoundError(patID)
}
return nil
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
return status.Errorf(status.Internal, "failed to save pat to store")
}
return nil
}
// DeletePAT deletes a personal access token from the database.
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete pat from store")
}
if result.RowsAffected == 0 {
return status.NewPATNotFoundError(patID)
}
return nil
}

View File

@ -627,29 +627,6 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func TestSqlite_GetUserByTokenID(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func TestMigrate(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
@ -962,23 +939,6 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
require.Equal(t, id, token)
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
func TestSqlite_GetTakenIPs(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
@ -1182,7 +1142,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
assert.NoError(t, err)
}
func TestSqlite_GetAccoundUsers(t *testing.T) {
func TestSqlStore_GetAccountUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@ -2915,3 +2875,326 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) {
t.Logf("Test completed")
}
func TestSqlStore_GetAccountCreatedBy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectError bool
createdBy string
}{
{
name: "existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectError: false,
createdBy: "edafee4e-63fb-11ec-90d6-0242ac120003",
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectError: true,
},
{
name: "empty account ID",
accountID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthShare, tt.accountID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Empty(t, createdBy)
} else {
require.NoError(t, err)
require.NotNil(t, createdBy)
require.Equal(t, tt.createdBy, createdBy)
}
})
}
}
func TestSqlStore_GetUserByUserID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
userID string
expectError bool
}{
{
name: "retrieve existing user",
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
expectError: false,
},
{
name: "retrieve non-existing user",
userID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty user ID",
userID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, tt.userID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, user)
} else {
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, tt.userID, user.Id)
}
})
}
}
func TestSqlStore_GetUserByPATID(t *testing.T) {
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err)
require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
}
func TestSqlStore_SaveUser(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
user := &types.User{
Id: "user-id",
AccountID: accountID,
Role: types.UserRoleAdmin,
IsServiceUser: false,
AutoGroups: []string{"groupA", "groupB"},
Blocked: false,
LastLogin: util.ToPtr(time.Now().UTC()),
CreatedAt: time.Now().UTC().Add(-time.Hour),
Issued: types.UserIssuedIntegration,
}
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
require.NoError(t, err)
saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, user.Id)
require.NoError(t, err)
require.Equal(t, user.Id, saveUser.Id)
require.Equal(t, user.AccountID, saveUser.AccountID)
require.Equal(t, user.Role, saveUser.Role)
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.Equal(t, user.Issued, saveUser.Issued)
require.Equal(t, user.Blocked, saveUser.Blocked)
require.Equal(t, user.IsServiceUser, saveUser.IsServiceUser)
}
func TestSqlStore_SaveUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, accountUsers, 2)
users := []*types.User{
{
Id: "user-1",
AccountID: accountID,
Issued: "api",
AutoGroups: []string{"groupA", "groupB"},
},
{
Id: "user-2",
AccountID: accountID,
Issued: "integration",
AutoGroups: []string{"groupA"},
},
}
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users)
require.NoError(t, err)
accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, accountUsers, 4)
}
func TestSqlStore_DeleteUser(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, accountID, userID)
require.NoError(t, err)
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, userID)
require.Error(t, err)
require.Nil(t, user)
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, userID)
require.NoError(t, err)
require.Len(t, userPATs, 0)
}
func TestSqlStore_GetPATByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
tests := []struct {
name string
patID string
expectError bool
}{
{
name: "retrieve existing PAT",
patID: "9dj38s35-63fb-11ec-90d6-0242ac120003",
expectError: false,
},
{
name: "retrieve non-existing PAT",
patID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty PAT ID",
patID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, tt.patID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, pat)
} else {
require.NoError(t, err)
require.NotNil(t, pat)
require.Equal(t, tt.patID, pat.ID)
}
})
}
}
func TestSqlStore_GetUserPATs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, "f4f6d672-63fb-11ec-90d6-0242ac120003")
require.NoError(t, err)
require.Len(t, userPATs, 1)
}
func TestSqlStore_GetPATByHashedToken(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "SoMeHaShEdToKeN")
require.NoError(t, err)
require.Equal(t, "9dj38s35-63fb-11ec-90d6-0242ac120003", pat.ID)
}
func TestSqlStore_MarkPATUsed(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
err = store.MarkPATUsed(context.Background(), LockingStrengthUpdate, patID)
require.NoError(t, err)
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
require.NoError(t, err)
now := time.Now().UTC()
require.WithinRange(t, pat.LastUsed.UTC(), now.Add(-15*time.Second), now, "LastUsed should be within 1 second of now")
}
func TestSqlStore_SavePAT(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
pat := &types.PersonalAccessToken{
ID: "pat-id",
UserID: userID,
Name: "token",
HashedToken: "SoMeHaShEdToKeN",
ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
CreatedBy: userID,
CreatedAt: time.Now().UTC().Add(time.Hour),
LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
}
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
require.NoError(t, err)
savePAT, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, pat.ID)
require.NoError(t, err)
require.Equal(t, pat.ID, savePAT.ID)
require.Equal(t, pat.UserID, savePAT.UserID)
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
}
func TestSqlStore_DeletePAT(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
err = store.DeletePAT(context.Background(), LockingStrengthUpdate, userID, patID)
require.NoError(t, err)
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
require.Error(t, err)
require.Nil(t, pat)
}

View File

@ -59,21 +59,30 @@ type Store interface {
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
SaveAccount(ctx context.Context, account *types.Account) error
DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
SaveUsers(accountID string, users map[string]*types.User) error
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)

View File

@ -37,7 +37,7 @@ CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`);
CREATE INDEX `idx_networks_id` ON `networks`(`id`);
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');

View File

@ -75,7 +75,7 @@ type PersonalAccessTokenGenerated struct {
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateNewToken()
if err != nil {
return nil, err
@ -84,6 +84,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
return &PersonalAccessTokenGenerated{
PersonalAccessToken: PersonalAccessToken{
ID: xid.New().String(),
UserID: targetID,
Name: name,
HashedToken: hashedToken,
ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)),

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp"
"github.com/netbirdio/netbird/management/server/util"
"golang.org/x/exp/maps"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
@ -45,7 +46,7 @@ const (
)
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
@ -53,13 +54,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err = store.SaveAccount(context.Background(), account)
err = s.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
Store: s,
eventStore: &activity.InMemoryEventStore{},
}
@ -81,7 +82,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err)
}
@ -855,7 +856,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
{
name: "Delete non-existent user",
userIDs: []string{"non-existent-user"},
expectedReasons: []string{"target user: non-existent-user not found"},
expectedReasons: []string{"user: non-existent-user not found"},
expectedNotDeleted: []string{},
},
{
@ -867,7 +868,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
userInfos, err := am.BuildUserInfosForAccount(context.Background(), mockAccountID, mockUserID, maps.Values(account.Users))
assert.NoError(t, err)
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs, userInfos)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int