mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-30 18:46:08 +02:00
Redeem invite only when incoming user was invited (#1861)
checks for users with pending invite status in the cache that already logged in and refresh the cache
This commit is contained in:
parent
9e01155d2e
commit
a80c8b0176
@ -46,6 +46,8 @@ const (
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
|
||||
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
|
||||
|
||||
func cacheEntryExpiration() time.Duration {
|
||||
@ -1092,13 +1094,15 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
}
|
||||
delete(userData, idp.UnsetAccountID)
|
||||
|
||||
rcvdUsers := 0
|
||||
for accountID, users := range userData {
|
||||
rcvdUsers += len(users)
|
||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.Infof("warmed up IDP cache with %d entries", len(userData))
|
||||
log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1263,7 +1267,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
|
||||
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
// ignore service users and users provisioned by integrations than are never logged in
|
||||
for _, user := range account.Users {
|
||||
if user.IsServiceUser {
|
||||
@ -1272,7 +1276,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
|
||||
if user.Issued == UserIssuedIntegration {
|
||||
continue
|
||||
}
|
||||
users[user.Id] = struct{}{}
|
||||
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
|
||||
}
|
||||
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||
userData, err := am.lookupCache(users, account.Id)
|
||||
@ -1345,22 +1349,30 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
|
||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) {
|
||||
data, err := am.getAccountFromCache(accountID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userDataMap := make(map[string]struct{})
|
||||
userDataMap := make(map[string]*idp.UserData, len(data))
|
||||
for _, datum := range data {
|
||||
userDataMap[datum.ID] = struct{}{}
|
||||
userDataMap[datum.ID] = datum
|
||||
}
|
||||
|
||||
mustRefreshInviteStatus := false
|
||||
|
||||
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
||||
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
|
||||
knownUsersCount := len(accountUsers)
|
||||
for user := range accountUsers {
|
||||
if _, ok := userDataMap[user]; ok {
|
||||
for user, loggedInOnce := range accountUsers {
|
||||
if datum, ok := userDataMap[user]; ok {
|
||||
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
|
||||
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
|
||||
mustRefreshInviteStatus = true
|
||||
log.Infof("user %s has a pending invite and has logged in once, forcing cache refresh", user)
|
||||
break
|
||||
}
|
||||
knownUsersCount--
|
||||
continue
|
||||
}
|
||||
@ -1368,8 +1380,10 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
|
||||
}
|
||||
|
||||
// if we know users that are not yet in cache more likely cache is outdated
|
||||
if knownUsersCount > 0 {
|
||||
log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount)
|
||||
if knownUsersCount > 0 || mustRefreshInviteStatus {
|
||||
if !mustRefreshInviteStatus {
|
||||
log.Infof("reloading cache with IDP manager. Users unknown to the cache: %d", knownUsersCount)
|
||||
}
|
||||
// reload cache once avoiding loops
|
||||
data, err = am.refreshCache(accountID)
|
||||
if err != nil {
|
||||
@ -1649,7 +1663,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
if !user.IsServiceUser {
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -198,8 +198,6 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
serviceUser := r.URL.Query().Get("service_user")
|
||||
|
||||
log.Debugf("UserCount: %v", len(data))
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
if r.NonDeletable {
|
||||
|
@ -13,6 +13,7 @@ type AuthorizationClaims struct {
|
||||
Domain string
|
||||
DomainCategory string
|
||||
LastLogin time.Time
|
||||
Invited bool
|
||||
|
||||
Raw jwt.MapClaims
|
||||
}
|
||||
|
@ -20,6 +20,8 @@ const (
|
||||
UserIDClaim = "sub"
|
||||
// LastLoginSuffix claim for the last login
|
||||
LastLoginSuffix = "nb_last_login"
|
||||
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
|
||||
Invited = "nb_invited"
|
||||
)
|
||||
|
||||
// ExtractClaims Extract function type
|
||||
@ -100,6 +102,10 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
if ok {
|
||||
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
|
||||
}
|
||||
invitedBool, ok := claims[c.authAudience+Invited]
|
||||
if ok {
|
||||
jwtClaims.Invited = invitedBool.(bool)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
|
||||
|
@ -30,6 +30,10 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st
|
||||
if claims.LastLogin != (time.Time{}) {
|
||||
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
|
||||
}
|
||||
|
||||
if claims.Invited {
|
||||
claimMaps[audience+Invited] = true
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
require.NoError(t, err, "creating testing request failed")
|
||||
@ -59,12 +63,14 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
AccountId: "testAcc",
|
||||
LastLogin: lastLogin,
|
||||
DomainCategory: "public",
|
||||
Invited: true,
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_domain_category": "public",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"https://login/nb_last_login": lastLogin.Format(layout),
|
||||
"sub": "test",
|
||||
"https://login/" + Invited: true,
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
|
@ -960,7 +960,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
|
||||
queriedUsers := make([]*idp.UserData, 0)
|
||||
if !isNil(am.idpManager) {
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
usersFromIntegration := make([]*idp.UserData, 0)
|
||||
for _, user := range account.Users {
|
||||
if user.Issued == UserIssuedIntegration {
|
||||
@ -968,14 +968,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
|
||||
users[user.Id] = struct{}{}
|
||||
users[user.Id] = true
|
||||
continue
|
||||
}
|
||||
usersFromIntegration = append(usersFromIntegration, info)
|
||||
continue
|
||||
}
|
||||
if !user.IsServiceUser {
|
||||
users[user.Id] = struct{}{}
|
||||
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
|
||||
}
|
||||
}
|
||||
queriedUsers, err = am.lookupCache(users, accountID)
|
||||
|
Loading…
Reference in New Issue
Block a user