mirror of
https://github.com/netbirdio/netbird.git
synced 2025-05-29 14:22:41 +02:00
Redeem invite only when incoming user was invited (#1861)
checks for users with pending invite status in the cache that already logged in and refresh the cache
This commit is contained in:
parent
9e01155d2e
commit
a80c8b0176
@ -46,6 +46,8 @@ const (
|
|||||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type userLoggedInOnce bool
|
||||||
|
|
||||||
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
|
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
|
||||||
|
|
||||||
func cacheEntryExpiration() time.Duration {
|
func cacheEntryExpiration() time.Duration {
|
||||||
@ -1092,13 +1094,15 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
|||||||
}
|
}
|
||||||
delete(userData, idp.UnsetAccountID)
|
delete(userData, idp.UnsetAccountID)
|
||||||
|
|
||||||
|
rcvdUsers := 0
|
||||||
for accountID, users := range userData {
|
for accountID, users := range userData {
|
||||||
|
rcvdUsers += len(users)
|
||||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
|
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Infof("warmed up IDP cache with %d entries", len(userData))
|
log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1263,7 +1267,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
|
|||||||
|
|
||||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||||
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
||||||
users := make(map[string]struct{}, len(account.Users))
|
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||||
// ignore service users and users provisioned by integrations than are never logged in
|
// ignore service users and users provisioned by integrations than are never logged in
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
if user.IsServiceUser {
|
if user.IsServiceUser {
|
||||||
@ -1272,7 +1276,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
|
|||||||
if user.Issued == UserIssuedIntegration {
|
if user.Issued == UserIssuedIntegration {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
users[user.Id] = struct{}{}
|
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
|
||||||
}
|
}
|
||||||
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||||
userData, err := am.lookupCache(users, account.Id)
|
userData, err := am.lookupCache(users, account.Id)
|
||||||
@ -1345,22 +1349,30 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
|
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) {
|
||||||
data, err := am.getAccountFromCache(accountID, false)
|
data, err := am.getAccountFromCache(accountID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userDataMap := make(map[string]struct{})
|
userDataMap := make(map[string]*idp.UserData, len(data))
|
||||||
for _, datum := range data {
|
for _, datum := range data {
|
||||||
userDataMap[datum.ID] = struct{}{}
|
userDataMap[datum.ID] = datum
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mustRefreshInviteStatus := false
|
||||||
|
|
||||||
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
||||||
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
|
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
|
||||||
knownUsersCount := len(accountUsers)
|
knownUsersCount := len(accountUsers)
|
||||||
for user := range accountUsers {
|
for user, loggedInOnce := range accountUsers {
|
||||||
if _, ok := userDataMap[user]; ok {
|
if datum, ok := userDataMap[user]; ok {
|
||||||
|
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
|
||||||
|
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
|
||||||
|
mustRefreshInviteStatus = true
|
||||||
|
log.Infof("user %s has a pending invite and has logged in once, forcing cache refresh", user)
|
||||||
|
break
|
||||||
|
}
|
||||||
knownUsersCount--
|
knownUsersCount--
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1368,8 +1380,10 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
|
|||||||
}
|
}
|
||||||
|
|
||||||
// if we know users that are not yet in cache more likely cache is outdated
|
// if we know users that are not yet in cache more likely cache is outdated
|
||||||
if knownUsersCount > 0 {
|
if knownUsersCount > 0 || mustRefreshInviteStatus {
|
||||||
log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount)
|
if !mustRefreshInviteStatus {
|
||||||
|
log.Infof("reloading cache with IDP manager. Users unknown to the cache: %d", knownUsersCount)
|
||||||
|
}
|
||||||
// reload cache once avoiding loops
|
// reload cache once avoiding loops
|
||||||
data, err = am.refreshCache(accountID)
|
data, err = am.refreshCache(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1649,7 +1663,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
|||||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(account, claims.UserId)
|
err = am.redeemInvite(account, claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -198,8 +198,6 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
serviceUser := r.URL.Query().Get("service_user")
|
serviceUser := r.URL.Query().Get("service_user")
|
||||||
|
|
||||||
log.Debugf("UserCount: %v", len(data))
|
|
||||||
|
|
||||||
users := make([]*api.User, 0)
|
users := make([]*api.User, 0)
|
||||||
for _, r := range data {
|
for _, r := range data {
|
||||||
if r.NonDeletable {
|
if r.NonDeletable {
|
||||||
|
@ -13,6 +13,7 @@ type AuthorizationClaims struct {
|
|||||||
Domain string
|
Domain string
|
||||||
DomainCategory string
|
DomainCategory string
|
||||||
LastLogin time.Time
|
LastLogin time.Time
|
||||||
|
Invited bool
|
||||||
|
|
||||||
Raw jwt.MapClaims
|
Raw jwt.MapClaims
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,8 @@ const (
|
|||||||
UserIDClaim = "sub"
|
UserIDClaim = "sub"
|
||||||
// LastLoginSuffix claim for the last login
|
// LastLoginSuffix claim for the last login
|
||||||
LastLoginSuffix = "nb_last_login"
|
LastLoginSuffix = "nb_last_login"
|
||||||
|
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
|
||||||
|
Invited = "nb_invited"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExtractClaims Extract function type
|
// ExtractClaims Extract function type
|
||||||
@ -100,6 +102,10 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
|||||||
if ok {
|
if ok {
|
||||||
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
|
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
|
||||||
}
|
}
|
||||||
|
invitedBool, ok := claims[c.authAudience+Invited]
|
||||||
|
if ok {
|
||||||
|
jwtClaims.Invited = invitedBool.(bool)
|
||||||
|
}
|
||||||
return jwtClaims
|
return jwtClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,6 +30,10 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st
|
|||||||
if claims.LastLogin != (time.Time{}) {
|
if claims.LastLogin != (time.Time{}) {
|
||||||
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
|
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if claims.Invited {
|
||||||
|
claimMaps[audience+Invited] = true
|
||||||
|
}
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||||
require.NoError(t, err, "creating testing request failed")
|
require.NoError(t, err, "creating testing request failed")
|
||||||
@ -59,12 +63,14 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
|||||||
AccountId: "testAcc",
|
AccountId: "testAcc",
|
||||||
LastLogin: lastLogin,
|
LastLogin: lastLogin,
|
||||||
DomainCategory: "public",
|
DomainCategory: "public",
|
||||||
|
Invited: true,
|
||||||
Raw: jwt.MapClaims{
|
Raw: jwt.MapClaims{
|
||||||
"https://login/wt_account_domain": "test.com",
|
"https://login/wt_account_domain": "test.com",
|
||||||
"https://login/wt_account_domain_category": "public",
|
"https://login/wt_account_domain_category": "public",
|
||||||
"https://login/wt_account_id": "testAcc",
|
"https://login/wt_account_id": "testAcc",
|
||||||
"https://login/nb_last_login": lastLogin.Format(layout),
|
"https://login/nb_last_login": lastLogin.Format(layout),
|
||||||
"sub": "test",
|
"sub": "test",
|
||||||
|
"https://login/" + Invited: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testingFunc: require.EqualValues,
|
testingFunc: require.EqualValues,
|
||||||
|
@ -960,7 +960,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
|||||||
|
|
||||||
queriedUsers := make([]*idp.UserData, 0)
|
queriedUsers := make([]*idp.UserData, 0)
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
users := make(map[string]struct{}, len(account.Users))
|
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||||
usersFromIntegration := make([]*idp.UserData, 0)
|
usersFromIntegration := make([]*idp.UserData, 0)
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
if user.Issued == UserIssuedIntegration {
|
if user.Issued == UserIssuedIntegration {
|
||||||
@ -968,14 +968,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
|||||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
info, err := am.externalCacheManager.Get(am.ctx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
|
log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
|
||||||
users[user.Id] = struct{}{}
|
users[user.Id] = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
usersFromIntegration = append(usersFromIntegration, info)
|
usersFromIntegration = append(usersFromIntegration, info)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !user.IsServiceUser {
|
if !user.IsServiceUser {
|
||||||
users[user.Id] = struct{}{}
|
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
queriedUsers, err = am.lookupCache(users, accountID)
|
queriedUsers, err = am.lookupCache(users, accountID)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user