Redeem invite only when incoming user was invited (#1861)

checks for users with pending invite status in the cache that already logged in and refresh the cache
This commit is contained in:
Maycon Santos 2024-04-22 11:10:27 +02:00 committed by GitHub
parent 9e01155d2e
commit a80c8b0176
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 41 additions and 16 deletions

View File

@ -46,6 +46,8 @@ const (
DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerLoginExpiration = 24 * time.Hour
) )
type userLoggedInOnce bool
type ExternalCacheManager cache.CacheInterface[*idp.UserData] type ExternalCacheManager cache.CacheInterface[*idp.UserData]
func cacheEntryExpiration() time.Duration { func cacheEntryExpiration() time.Duration {
@ -1092,13 +1094,15 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
} }
delete(userData, idp.UnsetAccountID) delete(userData, idp.UnsetAccountID)
rcvdUsers := 0
for accountID, users := range userData { for accountID, users := range userData {
rcvdUsers += len(users)
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil { if err != nil {
return err return err
} }
} }
log.Infof("warmed up IDP cache with %d entries", len(userData)) log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
return nil return nil
} }
@ -1263,7 +1267,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users)) users := make(map[string]userLoggedInOnce, len(account.Users))
// ignore service users and users provisioned by integrations than are never logged in // ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users { for _, user := range account.Users {
if user.IsServiceUser { if user.IsServiceUser {
@ -1272,7 +1276,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
if user.Issued == UserIssuedIntegration { if user.Issued == UserIssuedIntegration {
continue continue
} }
users[user.Id] = struct{}{} users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
} }
log.Debugf("looking up user %s of account %s in cache", userID, account.Id) log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id) userData, err := am.lookupCache(users, account.Id)
@ -1345,22 +1349,30 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo
} }
} }
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) { func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false) data, err := am.getAccountFromCache(accountID, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userDataMap := make(map[string]struct{}) userDataMap := make(map[string]*idp.UserData, len(data))
for _, datum := range data { for _, datum := range data {
userDataMap[datum.ID] = struct{}{} userDataMap[datum.ID] = datum
} }
mustRefreshInviteStatus := false
// the accountUsers ID list of non integration users from store, we check if cache has all of them // the accountUsers ID list of non integration users from store, we check if cache has all of them
// as result of for loop knownUsersCount will have number of users are not presented in the cashed // as result of for loop knownUsersCount will have number of users are not presented in the cashed
knownUsersCount := len(accountUsers) knownUsersCount := len(accountUsers)
for user := range accountUsers { for user, loggedInOnce := range accountUsers {
if _, ok := userDataMap[user]; ok { if datum, ok := userDataMap[user]; ok {
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
mustRefreshInviteStatus = true
log.Infof("user %s has a pending invite and has logged in once, forcing cache refresh", user)
break
}
knownUsersCount-- knownUsersCount--
continue continue
} }
@ -1368,8 +1380,10 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
} }
// if we know users that are not yet in cache more likely cache is outdated // if we know users that are not yet in cache more likely cache is outdated
if knownUsersCount > 0 { if knownUsersCount > 0 || mustRefreshInviteStatus {
log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount) if !mustRefreshInviteStatus {
log.Infof("reloading cache with IDP manager. Users unknown to the cache: %d", knownUsersCount)
}
// reload cache once avoiding loops // reload cache once avoiding loops
data, err = am.refreshCache(accountID) data, err = am.refreshCache(accountID)
if err != nil { if err != nil {
@ -1649,7 +1663,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if !user.IsServiceUser { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(account, claims.UserId) err = am.redeemInvite(account, claims.UserId)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@ -198,8 +198,6 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
serviceUser := r.URL.Query().Get("service_user") serviceUser := r.URL.Query().Get("service_user")
log.Debugf("UserCount: %v", len(data))
users := make([]*api.User, 0) users := make([]*api.User, 0)
for _, r := range data { for _, r := range data {
if r.NonDeletable { if r.NonDeletable {

View File

@ -13,6 +13,7 @@ type AuthorizationClaims struct {
Domain string Domain string
DomainCategory string DomainCategory string
LastLogin time.Time LastLogin time.Time
Invited bool
Raw jwt.MapClaims Raw jwt.MapClaims
} }

View File

@ -20,6 +20,8 @@ const (
UserIDClaim = "sub" UserIDClaim = "sub"
// LastLoginSuffix claim for the last login // LastLoginSuffix claim for the last login
LastLoginSuffix = "nb_last_login" LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
) )
// ExtractClaims Extract function type // ExtractClaims Extract function type
@ -100,6 +102,10 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
if ok { if ok {
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string)) jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
} }
invitedBool, ok := claims[c.authAudience+Invited]
if ok {
jwtClaims.Invited = invitedBool.(bool)
}
return jwtClaims return jwtClaims
} }

View File

@ -30,6 +30,10 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st
if claims.LastLogin != (time.Time{}) { if claims.LastLogin != (time.Time{}) {
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout) claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
} }
if claims.Invited {
claimMaps[audience+Invited] = true
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed") require.NoError(t, err, "creating testing request failed")
@ -59,12 +63,14 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
AccountId: "testAcc", AccountId: "testAcc",
LastLogin: lastLogin, LastLogin: lastLogin,
DomainCategory: "public", DomainCategory: "public",
Invited: true,
Raw: jwt.MapClaims{ Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com", "https://login/wt_account_domain": "test.com",
"https://login/wt_account_domain_category": "public", "https://login/wt_account_domain_category": "public",
"https://login/wt_account_id": "testAcc", "https://login/wt_account_id": "testAcc",
"https://login/nb_last_login": lastLogin.Format(layout), "https://login/nb_last_login": lastLogin.Format(layout),
"sub": "test", "sub": "test",
"https://login/" + Invited: true,
}, },
}, },
testingFunc: require.EqualValues, testingFunc: require.EqualValues,

View File

@ -960,7 +960,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
queriedUsers := make([]*idp.UserData, 0) queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
users := make(map[string]struct{}, len(account.Users)) users := make(map[string]userLoggedInOnce, len(account.Users))
usersFromIntegration := make([]*idp.UserData, 0) usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range account.Users { for _, user := range account.Users {
if user.Issued == UserIssuedIntegration { if user.Issued == UserIssuedIntegration {
@ -968,14 +968,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
info, err := am.externalCacheManager.Get(am.ctx, key) info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil { if err != nil {
log.Infof("Get ExternalCache for key: %s, error: %s", key, err) log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = struct{}{} users[user.Id] = true
continue continue
} }
usersFromIntegration = append(usersFromIntegration, info) usersFromIntegration = append(usersFromIntegration, info)
continue continue
} }
if !user.IsServiceUser { if !user.IsServiceUser {
users[user.Id] = struct{}{} users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
} }
} }
queriedUsers, err = am.lookupCache(users, accountID) queriedUsers, err = am.lookupCache(users, accountID)