diff --git a/management/server/account_test.go b/management/server/account_test.go index 25d501c8b..f21c93f0e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -495,6 +495,44 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) } +func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { + store := newStore(t) + account := newAccountWithId("account_id", "testuser", "") + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + account.Users["someUser"] = &User{ + Id: "someUser", + PATs: map[string]*PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + HashedToken: encodedHashedToken, + LastUsed: time.Time{}, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + err = am.MarkPATUsed("tokenId") + if err != nil { + t.Fatalf("Error when marking PAT used: %s", err) + } + + account, err = am.Store.GetAccount("account_id") + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) +} + func TestAccountManager_PrivateAccount(t *testing.T) { manager, err := createManager(t) if err != nil { diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0b1478ef3..54889466e 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -124,10 +124,10 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ } claimMaps := jwt.MapClaims{} - claimMaps[jwtclaims.UserIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[string(jwtclaims.UserIDClaim)] = user.Id + claimMaps[m.audience+string(jwtclaims.AccountIDSuffix)] = account.Id + claimMaps[m.audience+string(jwtclaims.DomainIDSuffix)] = account.Domain + claimMaps[m.audience+string(jwtclaims.DomainCategorySuffix)] = account.DomainCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) // Update the current request with the new context information. diff --git a/management/server/http/middleware/auth_midleware_test.go b/management/server/http/middleware/auth_midleware_test.go deleted file mode 100644 index c870d7c16..000000000 --- a/management/server/http/middleware/auth_midleware_test.go +++ /dev/null @@ -1 +0,0 @@ -package middleware diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 9d60da335..5063d7b91 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -6,12 +6,14 @@ import ( "github.com/golang-jwt/jwt" ) +type key string + const ( - TokenUserProperty = "user" - AccountIDSuffix = "wt_account_id" - DomainIDSuffix = "wt_account_domain" - DomainCategorySuffix = "wt_account_domain_category" - UserIDClaim = "sub" + TokenUserProperty key = "user" + AccountIDSuffix key = "wt_account_id" + DomainIDSuffix key = "wt_account_domain" + DomainCategorySuffix key = "wt_account_domain_category" + UserIDClaim key = "sub" ) // Extract function type @@ -60,7 +62,7 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { ce.FromRequestContext = ce.fromRequestContext } if ce.userIDClaim == "" { - ce.userIDClaim = UserIDClaim + ce.userIDClaim = string(UserIDClaim) } return ce } @@ -74,15 +76,15 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { return jwtClaims } jwtClaims.UserId = userID - accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] + accountIDClaim, ok := claims[c.authAudience+string(AccountIDSuffix)] if ok { jwtClaims.AccountId = accountIDClaim.(string) } - domainClaim, ok := claims[c.authAudience+DomainIDSuffix] + domainClaim, ok := claims[c.authAudience+string(DomainIDSuffix)] if ok { jwtClaims.Domain = domainClaim.(string) } - domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] + domainCategoryClaim, ok := claims[c.authAudience+string(DomainCategorySuffix)] if ok { jwtClaims.DomainCategory = domainCategoryClaim.(string) } diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index d8acd79b6..d8476e039 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -12,21 +12,21 @@ import ( func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { claimMaps := jwt.MapClaims{} if claims.UserId != "" { - claimMaps[UserIDClaim] = claims.UserId + claimMaps[string(UserIDClaim)] = claims.UserId } if claims.AccountId != "" { - claimMaps[audiance+AccountIDSuffix] = claims.AccountId + claimMaps[audiance+string(AccountIDSuffix)] = claims.AccountId } if claims.Domain != "" { - claimMaps[audiance+DomainIDSuffix] = claims.Domain + claimMaps[audiance+string(DomainIDSuffix)] = claims.Domain } if claims.DomainCategory != "" { - claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory + claimMaps[audiance+string(DomainCategorySuffix)] = claims.DomainCategory } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) require.NoError(t, err, "creating testing request failed") - testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint + testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint return testRequest } @@ -124,7 +124,7 @@ func TestExtractClaimsSetOptions(t *testing.T) { t.Error("audience should be empty") return } - if c.extractor.userIDClaim != UserIDClaim { + if c.extractor.userIDClaim != string(UserIDClaim) { t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) return }