mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-19 00:06:58 +02:00
[management] Refactor users to use store methods (#2917)
* Refactor setup key handling to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add lock to get account groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add check for regular user Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * get only required groups for auto-group validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add account lock and return auto groups map on validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor account peers update Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor groups to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Run groups ops in transaction Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix missing group removed from setup key activity Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor posture checks to remove get and save account Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix sonar Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Change setup key log level to debug for missing group Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Retrieve modified peers once for group events Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor policy get and save account to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Retrieve policy groups and posture checks once for validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix typo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add policy tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor anyGroupHasPeers to retrieve all groups once Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor dns settings to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add account locking and merge group deletion methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor name server groups to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add peer store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor ephemeral peers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add lock for peer store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor peer handlers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor peer to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix typo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add locks and remove log Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * run peer ops in transaction Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove duplicate store method Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix peer fields updated after save Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use update strength and simplify check Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * prevent changing ruleID when not empty Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * prevent duplicate rules during updates Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix lint Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor auth middleware Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor account methods and mock Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor user and PAT handling Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Remove db query context and fix get user by id Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix database transaction locking issue Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use UTC time in test Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add account locks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix prevent users from creating PATs for other users Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add missing tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor test names and remove duplicate TestPostgresql_SavePeerStatus Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add account locks and remove redundant ephemeral check Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Retrieve all groups for peers and restrict groups for regular users Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix store tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * use account object to get validated peers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Improve peer performance Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Get account direct from store without buffer Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add get peer groups tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Adjust benchmarks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Adjust benchmarks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Update benchmark workflow (#3181) * update local benchmark expectations * update cloud expectations * Add status error for generic result error Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use integrated validator direct Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations * update expectations * update expectations * Refactor peer scheduler to retry every 3 seconds on errors Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations * fix validator * fix validator * fix validator * update timeouts * Refactor ToGroupsInfo to process slices of groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations * update expectations * update expectations * Bump integrations version Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor GetValidatedPeers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * go mod tidy Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Use peers and groups map for peers validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove mysql from api benchmark tests * Fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix blocked db calls on user auto groups update Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Skip user check for system initiated peer deletion Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Remove context in db calls Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Improve group peer/resource counting (#3192) * Fix sonar Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Adjust bench expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Rename GetAccountInfoFromPAT to GetTokenInfo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Remove global account lock for ListUsers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * build userinfo after updating users in db Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Optimize user bulk deletion (#3315) * refactor building user infos Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove unused code Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor GetUsersFromAccount to return a map of UserInfo instead of a slice Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Export BuildUserInfosForAccount to account manager Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fetch account user info once for bulk users save Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update user deletion expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Set max open conns for activity store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update bench expectations Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: Pascal Fischer <pascal@netbird.io> Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
This commit is contained in:
parent
abe8da697c
commit
4cdb2e533a
@ -260,6 +260,7 @@ func TestDNS_Integration(t *testing.T) {
|
|||||||
nsGroupReq := api.NameserverGroupRequest{
|
nsGroupReq := api.NameserverGroupRequest{
|
||||||
Description: "Test",
|
Description: "Test",
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
Domains: []string{},
|
||||||
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
|
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Nameservers: []api.Nameserver{
|
Nameservers: []api.Nameserver{
|
||||||
|
@ -67,7 +67,7 @@ type AccountManager interface {
|
|||||||
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
|
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
|
||||||
@ -79,7 +79,7 @@ type AccountManager interface {
|
|||||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||||
@ -96,7 +96,7 @@ type AccountManager interface {
|
|||||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
|
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
|
||||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
|
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||||
@ -149,6 +149,7 @@ type AccountManager interface {
|
|||||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||||
|
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@ -617,6 +618,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
if user.Role != types.UserRoleOwner {
|
if user.Role != types.UserRoleOwner {
|
||||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, otherUser := range account.Users {
|
for _, otherUser := range account.Users {
|
||||||
if otherUser.IsServiceUser {
|
if otherUser.IsServiceUser {
|
||||||
continue
|
continue
|
||||||
@ -626,13 +633,23 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id)
|
userInfo, ok := userInfosMap[otherUser.Id]
|
||||||
|
if !ok {
|
||||||
|
return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||||
if deleteUserErr != nil {
|
if deleteUserErr != nil {
|
||||||
return deleteUserErr
|
return deleteUserErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.deleteRegularUser(ctx, account, userID, userID)
|
userInfo, ok := userInfosMap[userID]
|
||||||
|
if !ok {
|
||||||
|
return status.Errorf(status.NotFound, "user info not found for user %s", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
|
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
|
||||||
return err
|
return err
|
||||||
@ -689,20 +706,8 @@ func isNil(i idp.Manager) bool {
|
|||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cachedAccount := &types.Account{
|
|
||||||
Id: accountID,
|
|
||||||
Users: make(map[string]*types.User),
|
|
||||||
}
|
|
||||||
for _, user := range accountUsers {
|
|
||||||
cachedAccount.Users[user.Id] = user
|
|
||||||
}
|
|
||||||
|
|
||||||
// user can be nil if it wasn't found (e.g., just created)
|
// user can be nil if it wasn't found (e.g., just created)
|
||||||
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -778,10 +783,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) {
|
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
|
||||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||||
// ignore service users and users provisioned by integrations than are never logged in
|
// ignore service users and users provisioned by integrations than are never logged in
|
||||||
for _, user := range account.Users {
|
for _, user := range accountUsers {
|
||||||
if user.IsServiceUser {
|
if user.IsServiceUser {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -790,8 +800,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
|||||||
}
|
}
|
||||||
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
|
||||||
userData, err := am.lookupCache(ctx, users, account.Id)
|
userData, err := am.lookupCache(ctx, users, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -804,13 +814,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
|||||||
|
|
||||||
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
||||||
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
||||||
user, err := account.FindUser(userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
|
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
key := user.IntegrationReference.CacheKey(account.Id, userID)
|
key := user.IntegrationReference.CacheKey(accountID, userID)
|
||||||
ud, err := am.externalCacheManager.Get(am.ctx, key)
|
ud, err := am.externalCacheManager.Get(am.ctx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
|
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
|
||||||
@ -1050,9 +1060,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
|||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
|
|
||||||
usersMap := make(map[string]*types.User)
|
newUser := types.NewRegularUser(claims.UserId)
|
||||||
usersMap[claims.UserId] = types.NewRegularUser(claims.UserId)
|
newUser.AccountID = domainAccountID
|
||||||
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -1075,12 +1085,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1090,17 +1095,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
|
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
||||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
||||||
// Our job is to just reload cache.
|
// Our job is to just reload cache.
|
||||||
go func() {
|
go func() {
|
||||||
_, err = am.refreshCache(ctx, account.Id)
|
_, err = am.refreshCache(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
|
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
|
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
|
||||||
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
|
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1109,33 +1114,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
|||||||
|
|
||||||
// MarkPATUsed marks a personal access token as used
|
// MarkPATUsed marks a personal access token as used
|
||||||
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||||
|
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
||||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err = am.Store.GetAccountByUser(ctx, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
pat, ok := account.Users[user.Id].PATs[tokenID]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("token not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
pat.LastUsed = util.ToPtr(time.Now().UTC())
|
|
||||||
|
|
||||||
return am.Store.SaveAccount(ctx, account)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccount returns an account associated with this account ID.
|
// GetAccount returns an account associated with this account ID.
|
||||||
@ -1143,52 +1122,64 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
|||||||
return am.Store.GetAccount(ctx, accountID)
|
return am.Store.GetAccount(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||||
|
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, pat, domain, category, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||||
|
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
|
||||||
if len(token) != types.PATLength {
|
if len(token) != types.PATLength {
|
||||||
return nil, nil, nil, fmt.Errorf("token has wrong length")
|
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := token[:len(types.PATPrefix)]
|
prefix := token[:len(types.PATPrefix)]
|
||||||
if prefix != types.PATPrefix {
|
if prefix != types.PATPrefix {
|
||||||
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
|
return nil, nil, fmt.Errorf("token has wrong prefix")
|
||||||
}
|
}
|
||||||
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
||||||
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
||||||
|
|
||||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||||
if secretChecksum != verificationChecksum {
|
if secretChecksum != verificationChecksum {
|
||||||
return nil, nil, nil, fmt.Errorf("token checksum does not match")
|
return nil, nil, fmt.Errorf("token checksum does not match")
|
||||||
}
|
}
|
||||||
|
|
||||||
hashedToken := sha256.Sum256([]byte(token))
|
hashedToken := sha256.Sum256([]byte(token))
|
||||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||||
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
|
|
||||||
|
var user *types.User
|
||||||
|
var pat *types.PersonalAccessToken
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
|
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
return user, pat, nil
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pat := user.PATs[tokenID]
|
|
||||||
if pat == nil {
|
|
||||||
return nil, nil, nil, fmt.Errorf("personal access token not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, user, pat, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByID returns an account associated with this account ID.
|
// GetAccountByID returns an account associated with this account ID.
|
||||||
@ -1334,7 +1325,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return fmt.Errorf("error getting user peers: %w", err)
|
return fmt.Errorf("error getting user peers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -732,6 +732,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
|||||||
PATs: map[string]*types.PersonalAccessToken{
|
PATs: map[string]*types.PersonalAccessToken{
|
||||||
"tokenId": {
|
"tokenId": {
|
||||||
ID: "tokenId",
|
ID: "tokenId",
|
||||||
|
UserID: "someUser",
|
||||||
HashedToken: encodedHashedToken,
|
HashedToken: encodedHashedToken,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -745,14 +746,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
|||||||
Store: store,
|
Store: store,
|
||||||
}
|
}
|
||||||
|
|
||||||
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
|
user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, "account_id", account.Id)
|
assert.Equal(t, "account_id", user.AccountID)
|
||||||
assert.Equal(t, "someUser", user.Id)
|
assert.Equal(t, "someUser", user.Id)
|
||||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
@ -95,6 +96,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
db.SetMaxOpenConns(runtime.NumCPU())
|
||||||
|
|
||||||
crypt, err := NewFieldEncrypt(encryptionKey)
|
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -43,7 +43,7 @@ func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, network
|
|||||||
)
|
)
|
||||||
|
|
||||||
authMiddleware := middleware.NewAuthMiddleware(
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
accountManager.GetAccountFromPAT,
|
accountManager.GetPATInfo,
|
||||||
jwtValidator.ValidateAndParse,
|
jwtValidator.ValidateAndParse,
|
||||||
accountManager.MarkPATUsed,
|
accountManager.MarkPATUsed,
|
||||||
accountManager.CheckUserAccessByJWTGroups,
|
accountManager.CheckUserAccessByJWTGroups,
|
||||||
|
@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
|
|||||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return claims.AccountId, claims.UserId, nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
},
|
},
|
||||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
|
||||||
return make([]*types.UserInfo, 0), nil
|
return make(map[string]*types.UserInfo), nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
|
@ -52,7 +52,7 @@ var usersTestAccount = &types.Account{
|
|||||||
Issued: types.UserIssuedAPI,
|
Issued: types.UserIssuedAPI,
|
||||||
},
|
},
|
||||||
nonDeletableServiceUserID: {
|
nonDeletableServiceUserID: {
|
||||||
Id: serviceUserID,
|
Id: nonDeletableServiceUserID,
|
||||||
Role: "admin",
|
Role: "admin",
|
||||||
IsServiceUser: true,
|
IsServiceUser: true,
|
||||||
NonDeletable: true,
|
NonDeletable: true,
|
||||||
@ -70,10 +70,10 @@ func initUsersTestData() *handler {
|
|||||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||||
return usersTestAccount.Users[id], nil
|
return usersTestAccount.Users[id], nil
|
||||||
},
|
},
|
||||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
|
||||||
users := make([]*types.UserInfo, 0)
|
usersInfos := make(map[string]*types.UserInfo)
|
||||||
for _, v := range usersTestAccount.Users {
|
for _, v := range usersTestAccount.Users {
|
||||||
users = append(users, &types.UserInfo{
|
usersInfos[v.Id] = &types.UserInfo{
|
||||||
ID: v.Id,
|
ID: v.Id,
|
||||||
Role: string(v.Role),
|
Role: string(v.Role),
|
||||||
Name: "",
|
Name: "",
|
||||||
@ -81,9 +81,9 @@ func initUsersTestData() *handler {
|
|||||||
IsServiceUser: v.IsServiceUser,
|
IsServiceUser: v.IsServiceUser,
|
||||||
NonDeletable: v.NonDeletable,
|
NonDeletable: v.NonDeletable,
|
||||||
Issued: v.Issued,
|
Issued: v.Issued,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
return users, nil
|
return usersInfos, nil
|
||||||
},
|
},
|
||||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
|
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
|
||||||
if userID != existingUserID {
|
if userID != existingUserID {
|
||||||
|
@ -19,8 +19,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAccountFromPATFunc function
|
// GetAccountInfoFromPATFunc function
|
||||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||||
|
|
||||||
// ValidateAndParseTokenFunc function
|
// ValidateAndParseTokenFunc function
|
||||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||||
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
|||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
getAccountFromPAT GetAccountFromPATFunc
|
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||||
validateAndParseToken ValidateAndParseTokenFunc
|
validateAndParseToken ValidateAndParseTokenFunc
|
||||||
markPATUsed MarkPATUsedFunc
|
markPATUsed MarkPATUsedFunc
|
||||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||||
@ -47,7 +47,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||||
audience string, userIdClaim string) *AuthMiddleware {
|
audience string, userIdClaim string) *AuthMiddleware {
|
||||||
if userIdClaim == "" {
|
if userIdClaim == "" {
|
||||||
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
getAccountFromPAT: getAccountFromPAT,
|
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||||
validateAndParseToken: validateAndParseToken,
|
validateAndParseToken: validateAndParseToken,
|
||||||
markPATUsed: markPATUsed,
|
markPATUsed: markPATUsed,
|
||||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||||
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
|||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||||
token, err := getTokenFromPATRequest(auth)
|
token, err := getTokenFromPATRequest(auth)
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid Token: %w", err)
|
return fmt.Errorf("invalid Token: %w", err)
|
||||||
}
|
}
|
||||||
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
claimMaps := jwt.MapClaims{}
|
claimMaps := jwt.MapClaims{}
|
||||||
claimMaps[m.userIDClaim] = user.Id
|
claimMaps[m.userIDClaim] = user.Id
|
||||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||||
claimMaps[jwtclaims.IsToken] = true
|
claimMaps[jwtclaims.IsToken] = true
|
||||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||||
|
@ -34,7 +34,8 @@ var testAccount = &types.Account{
|
|||||||
Domain: domain,
|
Domain: domain,
|
||||||
Users: map[string]*types.User{
|
Users: map[string]*types.User{
|
||||||
userID: {
|
userID: {
|
||||||
Id: userID,
|
Id: userID,
|
||||||
|
AccountID: accountID,
|
||||||
PATs: map[string]*types.PersonalAccessToken{
|
PATs: map[string]*types.PersonalAccessToken{
|
||||||
tokenID: {
|
tokenID: {
|
||||||
ID: tokenID,
|
ID: tokenID,
|
||||||
@ -50,11 +51,11 @@ var testAccount = &types.Account{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||||
if token == PAT {
|
if token == PAT {
|
||||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||||
}
|
}
|
||||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||||
@ -166,7 +167,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockGetAccountFromPAT,
|
mockGetAccountInfoFromPAT,
|
||||||
mockValidateAndParseToken,
|
mockValidateAndParseToken,
|
||||||
mockMarkPATUsed,
|
mockMarkPATUsed,
|
||||||
mockCheckUserAccessByJWTGroups,
|
mockCheckUserAccessByJWTGroups,
|
||||||
|
@ -35,14 +35,14 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{
|
|||||||
|
|
||||||
func BenchmarkUpdateUser(b *testing.B) {
|
func BenchmarkUpdateUser(b *testing.B) {
|
||||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
"Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000},
|
"Users - XS": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 310},
|
||||||
"Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50},
|
"Users - S": {MinMsPerOpLocal: 0.3, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
|
||||||
"Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250},
|
"Users - M": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 3, MaxMsPerOpCICD: 20},
|
||||||
"Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700},
|
"Users - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
|
||||||
"Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400},
|
"Peers - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 310},
|
||||||
"Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000},
|
"Groups - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 120},
|
||||||
"Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000},
|
"Setup Keys - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
|
||||||
"Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500},
|
"Users - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 280},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
@ -118,14 +118,14 @@ func BenchmarkGetOneUser(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkGetAllUsers(b *testing.B) {
|
func BenchmarkGetAllUsers(b *testing.B) {
|
||||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
"Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180},
|
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||||
"Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
|
||||||
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
|
||||||
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
|
||||||
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||||
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200},
|
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||||
"Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90},
|
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
@ -141,7 +141,7 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
|
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
|
||||||
apiHandler.ServeHTTP(recorder, req)
|
apiHandler.ServeHTTP(recorder, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,14 +152,14 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkDeleteUsers(b *testing.B) {
|
func BenchmarkDeleteUsers(b *testing.B) {
|
||||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
"Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000},
|
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
"Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
|
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
"Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230},
|
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
"Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190},
|
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
"Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800},
|
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
"Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500},
|
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
"Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600},
|
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
"Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400},
|
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
|
@ -53,8 +53,8 @@ type MockAccountManager struct {
|
|||||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
|
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
|
||||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
@ -69,7 +69,7 @@ type MockAccountManager struct {
|
|||||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||||
@ -110,6 +110,7 @@ type MockAccountManager struct {
|
|||||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||||
|
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||||
@ -165,7 +166,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) {
|
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) (map[string]*types.UserInfo, error) {
|
||||||
if am.GetUsersFromAccountFunc != nil {
|
if am.GetUsersFromAccountFunc != nil {
|
||||||
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
|
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
|
||||||
}
|
}
|
||||||
@ -238,12 +239,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
|
|||||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
|
||||||
if am.GetAccountFromPATFunc != nil {
|
if am.GetPATInfoFunc != nil {
|
||||||
return am.GetAccountFromPATFunc(ctx, pat)
|
return am.GetPATInfoFunc(ctx, pat)
|
||||||
}
|
}
|
||||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||||
@ -550,9 +551,9 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
|
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
|
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error {
|
||||||
if am.DeleteRegularUsersFunc != nil {
|
if am.DeleteRegularUsersFunc != nil {
|
||||||
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
|
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs, userInfos)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
|
||||||
}
|
}
|
||||||
@ -849,3 +850,11 @@ func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peer
|
|||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BuildUserInfosForAccount mocks BuildUserInfosForAccount of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
|
||||||
|
if am.BuildUserInfosForAccountFunc != nil {
|
||||||
|
return am.BuildUserInfosForAccountFunc(ctx, accountID, initiatorUserID, accountUsers)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
|
||||||
|
}
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -28,7 +29,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
@ -1554,7 +1554,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
// Adding peer to group linked with policy should update account peers and send peer update
|
// Adding peer to group linked with policy should update account peers and send peer update
|
||||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||||
Enabled: true,
|
AccountID: account.Id,
|
||||||
|
Enabled: true,
|
||||||
Rules: []*types.PolicyRule{
|
Rules: []*types.PolicyRule{
|
||||||
{
|
{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -93,7 +93,7 @@ func NewPeerNotPartOfAccountError() error {
|
|||||||
|
|
||||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||||
func NewUserNotFoundError(userKey string) error {
|
func NewUserNotFoundError(userKey string) error {
|
||||||
return Errorf(NotFound, "user not found: %s", userKey)
|
return Errorf(NotFound, "user: %s not found", userKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
|
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
|
||||||
@ -191,3 +191,18 @@ func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {
|
|||||||
func NewRouterNotPartOfNetworkError(routerID, networkID string) error {
|
func NewRouterNotPartOfNetworkError(routerID, networkID string) error {
|
||||||
return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID)
|
return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
|
||||||
|
func NewServiceUserRoleInvalidError() error {
|
||||||
|
return Errorf(InvalidArgument, "can't create a service user with owner role")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
|
||||||
|
// to delete a user with the owner role.
|
||||||
|
func NewOwnerDeletePermissionError() error {
|
||||||
|
return Errorf(PermissionDenied, "can't delete a user with the owner role")
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPATNotFoundError(patID string) error {
|
||||||
|
return Errorf(NotFound, "PAT: %s not found", patID)
|
||||||
|
}
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
@ -414,24 +415,16 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SaveUsers saves the given list of users to the database.
|
// SaveUsers saves the given list of users to the database.
|
||||||
// It updates existing users if a conflict occurs.
|
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error {
|
||||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
|
if len(users) == 0 {
|
||||||
usersToSave := make([]types.User, 0, len(users))
|
return nil
|
||||||
for _, user := range users {
|
|
||||||
user.AccountID = accountID
|
|
||||||
for id, pat := range user.PATs {
|
|
||||||
pat.ID = id
|
|
||||||
user.PATsG = append(user.PATsG, *pat)
|
|
||||||
}
|
|
||||||
usersToSave = append(usersToSave, *user)
|
|
||||||
}
|
|
||||||
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
|
||||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
|
||||||
Create(&usersToSave).Error
|
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to save users to store")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -439,7 +432,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) err
|
|||||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
|
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to save user to store")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -526,30 +520,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
|||||||
return token.ID, nil
|
return token.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
|
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
|
||||||
var token types.PersonalAccessToken
|
var user types.User
|
||||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||||
|
Where("personal_access_tokens.id = ?", patID).First(&user)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
return nil, status.NewPATNotFoundError(patID)
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
|
||||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
return nil, status.NewGetUserFromStoreError()
|
||||||
}
|
|
||||||
|
|
||||||
if token.UserID == "" {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
var user types.User
|
|
||||||
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
|
|
||||||
if result.Error != nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
|
|
||||||
for _, pat := range user.PATsG {
|
|
||||||
user.PATs[pat.ID] = pat.Copy()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
@ -557,8 +538,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types
|
|||||||
|
|
||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||||
var user types.User
|
var user types.User
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID)
|
||||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
@ -569,6 +549,25 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error {
|
||||||
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete user from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
||||||
var users []*types.User
|
var users []*types.User
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||||
@ -899,6 +898,20 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
|||||||
return accountSettings.Settings, nil
|
return accountSettings.Settings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
||||||
|
var createdBy string
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||||
|
Select("created_by").First(&createdBy, idQueryCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", status.NewAccountNotFoundError(accountID)
|
||||||
|
}
|
||||||
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return createdBy, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
var user types.User
|
var user types.User
|
||||||
@ -2053,3 +2066,94 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||||
|
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
|
||||||
|
var pat types.PersonalAccessToken
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewPATNotFoundError(hashedToken)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get pat by hash from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||||
|
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
|
||||||
|
var pat types.PersonalAccessToken
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewPATNotFoundError(patID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get pat from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserPATs retrieves personal access tokens for a user.
|
||||||
|
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
|
||||||
|
var pats []*types.PersonalAccessToken
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get user pat's from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkPATUsed marks a personal access token as used.
|
||||||
|
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
|
||||||
|
patCopy := types.PersonalAccessToken{
|
||||||
|
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldsToUpdate := []string{"last_used"}
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate).
|
||||||
|
Where(idQueryCondition, patID).Updates(&patCopy)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to mark pat as used")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewPATNotFoundError(patID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SavePAT saves a personal access token to the database.
|
||||||
|
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save pat to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePAT deletes a personal access token from the database.
|
||||||
|
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete pat from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewPATNotFoundError(patID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -627,29 +627,6 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
|||||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlite_GetUserByTokenID(t *testing.T) {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
|
||||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
|
||||||
|
|
||||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, id, user.PATs[id].ID)
|
|
||||||
|
|
||||||
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
|
|
||||||
require.Error(t, err)
|
|
||||||
parsedErr, ok := status.FromError(err)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrate(t *testing.T) {
|
func TestMigrate(t *testing.T) {
|
||||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||||
t.Skip("skip CI tests on darwin and windows")
|
t.Skip("skip CI tests on darwin and windows")
|
||||||
@ -962,23 +939,6 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
|||||||
require.Equal(t, id, token)
|
require.Equal(t, id, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
|
||||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
|
||||||
t.Skip("skip CI tests on darwin and windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
|
||||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanUp)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
|
||||||
|
|
||||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, id, user.PATs[id].ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
@ -1182,7 +1142,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
func TestSqlStore_GetAccountUsers(t *testing.T) {
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -2915,3 +2875,326 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) {
|
|||||||
|
|
||||||
t.Logf("Test completed")
|
t.Logf("Test completed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetAccountCreatedBy(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
expectError bool
|
||||||
|
createdBy string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "existing account ID",
|
||||||
|
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
expectError: false,
|
||||||
|
createdBy: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing account ID",
|
||||||
|
accountID: "nonexistent",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty account ID",
|
||||||
|
accountID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthShare, tt.accountID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Empty(t, createdBy)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, createdBy)
|
||||||
|
require.Equal(t, tt.createdBy, createdBy)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetUserByUserID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing user",
|
||||||
|
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing user",
|
||||||
|
userID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty user ID",
|
||||||
|
userID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, tt.userID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, user)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.Equal(t, tt.userID, user.Id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetUserByPATID(t *testing.T) {
|
||||||
|
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanUp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveUser(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
user := &types.User{
|
||||||
|
Id: "user-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Role: types.UserRoleAdmin,
|
||||||
|
IsServiceUser: false,
|
||||||
|
AutoGroups: []string{"groupA", "groupB"},
|
||||||
|
Blocked: false,
|
||||||
|
LastLogin: util.ToPtr(time.Now().UTC()),
|
||||||
|
CreatedAt: time.Now().UTC().Add(-time.Hour),
|
||||||
|
Issued: types.UserIssuedIntegration,
|
||||||
|
}
|
||||||
|
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, user.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, user.Id, saveUser.Id)
|
||||||
|
require.Equal(t, user.AccountID, saveUser.AccountID)
|
||||||
|
require.Equal(t, user.Role, saveUser.Role)
|
||||||
|
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
|
||||||
|
require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
|
||||||
|
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||||
|
require.Equal(t, user.Issued, saveUser.Issued)
|
||||||
|
require.Equal(t, user.Blocked, saveUser.Blocked)
|
||||||
|
require.Equal(t, user.IsServiceUser, saveUser.IsServiceUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveUsers(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, accountUsers, 2)
|
||||||
|
|
||||||
|
users := []*types.User{
|
||||||
|
{
|
||||||
|
Id: "user-1",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "api",
|
||||||
|
AutoGroups: []string{"groupA", "groupB"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "user-2",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "integration",
|
||||||
|
AutoGroups: []string{"groupA"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, accountUsers, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteUser(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, accountID, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, userID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, user)
|
||||||
|
|
||||||
|
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, userPATs, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPATByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
patID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing PAT",
|
||||||
|
patID: "9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing PAT",
|
||||||
|
patID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty PAT ID",
|
||||||
|
patID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, tt.patID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, pat)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, pat)
|
||||||
|
require.Equal(t, tt.patID, pat.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetUserPATs(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, "f4f6d672-63fb-11ec-90d6-0242ac120003")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, userPATs, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPATByHashedToken(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "SoMeHaShEdToKeN")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "9dj38s35-63fb-11ec-90d6-0242ac120003", pat.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_MarkPATUsed(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||||
|
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
err = store.MarkPATUsed(context.Background(), LockingStrengthUpdate, patID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
require.WithinRange(t, pat.LastUsed.UTC(), now.Add(-15*time.Second), now, "LastUsed should be within 1 second of now")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SavePAT(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
pat := &types.PersonalAccessToken{
|
||||||
|
ID: "pat-id",
|
||||||
|
UserID: userID,
|
||||||
|
Name: "token",
|
||||||
|
HashedToken: "SoMeHaShEdToKeN",
|
||||||
|
ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
|
||||||
|
CreatedBy: userID,
|
||||||
|
CreatedAt: time.Now().UTC().Add(time.Hour),
|
||||||
|
LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
|
||||||
|
}
|
||||||
|
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savePAT, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, pat.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, pat.ID, savePAT.ID)
|
||||||
|
require.Equal(t, pat.UserID, savePAT.UserID)
|
||||||
|
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
|
||||||
|
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
|
||||||
|
require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
|
||||||
|
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||||
|
require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeletePAT(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||||
|
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
|
err = store.DeletePAT(context.Background(), LockingStrengthUpdate, userID, patID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, pat)
|
||||||
|
}
|
||||||
|
@ -59,21 +59,30 @@ type Store interface {
|
|||||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
|
||||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
|
||||||
|
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
||||||
SaveAccount(ctx context.Context, account *types.Account) error
|
SaveAccount(ctx context.Context, account *types.Account) error
|
||||||
DeleteAccount(ctx context.Context, account *types.Account) error
|
DeleteAccount(ctx context.Context, account *types.Account) error
|
||||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error)
|
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
||||||
SaveUsers(accountID string, users map[string]*types.User) error
|
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
|
||||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
|
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
|
|
||||||
|
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||||
|
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||||
|
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
||||||
|
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||||
|
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error
|
||||||
|
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||||
|
|
||||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||||
|
2
management/server/testdata/store.sql
vendored
2
management/server/testdata/store.sql
vendored
@ -37,7 +37,7 @@ CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`);
|
|||||||
CREATE INDEX `idx_networks_id` ON `networks`(`id`);
|
CREATE INDEX `idx_networks_id` ON `networks`(`id`);
|
||||||
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
|
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
|
||||||
|
|
||||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||||
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||||
|
@ -75,7 +75,7 @@ type PersonalAccessTokenGenerated struct {
|
|||||||
|
|
||||||
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
||||||
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
||||||
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||||
hashedToken, plainToken, err := generateNewToken()
|
hashedToken, plainToken, err := generateNewToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -84,6 +84,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
|
|||||||
return &PersonalAccessTokenGenerated{
|
return &PersonalAccessTokenGenerated{
|
||||||
PersonalAccessToken: PersonalAccessToken{
|
PersonalAccessToken: PersonalAccessToken{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
|
UserID: targetID,
|
||||||
Name: name,
|
Name: name,
|
||||||
HashedToken: hashedToken,
|
HashedToken: hashedToken,
|
||||||
ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)),
|
ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)),
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,7 @@ import (
|
|||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@ -45,7 +46,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when creating store: %s", err)
|
t.Fatalf("Error when creating store: %s", err)
|
||||||
}
|
}
|
||||||
@ -53,13 +54,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
|||||||
|
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||||
|
|
||||||
err = store.SaveAccount(context.Background(), account)
|
err = s.SaveAccount(context.Background(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when saving account: %s", err)
|
t.Fatalf("Error when saving account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
am := DefaultAccountManager{
|
am := DefaultAccountManager{
|
||||||
Store: store,
|
Store: s,
|
||||||
eventStore: &activity.InMemoryEventStore{},
|
eventStore: &activity.InMemoryEventStore{},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +82,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, pat.ID, tokenID)
|
assert.Equal(t, pat.ID, tokenID)
|
||||||
|
|
||||||
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
|
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||||
}
|
}
|
||||||
@ -855,7 +856,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Delete non-existent user",
|
name: "Delete non-existent user",
|
||||||
userIDs: []string{"non-existent-user"},
|
userIDs: []string{"non-existent-user"},
|
||||||
expectedReasons: []string{"target user: non-existent-user not found"},
|
expectedReasons: []string{"user: non-existent-user not found"},
|
||||||
expectedNotDeleted: []string{},
|
expectedNotDeleted: []string{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -867,7 +868,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
|
userInfos, err := am.BuildUserInfosForAccount(context.Background(), mockAccountID, mockUserID, maps.Values(account.Users))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs, userInfos)
|
||||||
if len(tc.expectedReasons) > 0 {
|
if len(tc.expectedReasons) > 0 {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
var foundExpectedErrors int
|
var foundExpectedErrors int
|
||||||
|
Loading…
x
Reference in New Issue
Block a user