From 2bdb4cb44a8c128214eac0933d040f2e8447a92a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 31 Dec 2024 18:59:37 +0300 Subject: [PATCH] [management] Preserve jwt groups when accessing API with PAT (#3128) * Skip JWT group sync for token-based authentication Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 6 ++++ management/server/account_test.go | 30 +++++++++++++++++-- .../server/http/middleware/auth_middleware.go | 1 + management/server/jwtclaims/extractor.go | 2 ++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index e60b41b4e..83a8759f9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1252,6 +1252,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { + if claim, exists := claims.Raw[jwtclaims.IsToken]; exists { + if isToken, ok := claim.(bool); ok && isToken { + return nil + } + } + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err diff --git a/management/server/account_test.go b/management/server/account_test.go index 280d998fd..2289c96f9 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2729,6 +2729,19 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("skip sync for token auth type", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") + }) + t.Run("empty jwt groups", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2822,7 +2835,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.Len(t, user.AutoGroups, 1, "new group should be added") }) - t.Run("remove all JWT groups", func(t *testing.T) { + t.Run("remove all JWT groups when list is empty", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", Raw: jwt.MapClaims{"groups": []interface{}{}}, @@ -2833,7 +2846,20 @@ func TestAccount_SetJWTGroups(t *testing.T) { user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") + assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") + }) + + t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") }) } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0d3459712..0a54cbaed 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[jwtclaims.IsToken] = true jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint // Update the current request with the new context information. diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index c441650e9..18214b434 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -22,6 +22,8 @@ const ( LastLoginSuffix = "nb_last_login" // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation Invited = "nb_invited" + // IsToken claim indicates that auth type from the user is a token + IsToken = "is_token" ) // ExtractClaims Extract function type