add test + codacy

This commit is contained in:
Pascal Fischer 2023-03-30 16:43:39 +02:00
parent 2a79995706
commit 1343a3f00e
5 changed files with 59 additions and 20 deletions

View File

@ -495,6 +495,44 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
} }
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
}
err = am.MarkPATUsed("tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount("account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
}
func TestAccountManager_PrivateAccount(t *testing.T) { func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {

View File

@ -124,10 +124,10 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ
} }
claimMaps := jwt.MapClaims{} claimMaps := jwt.MapClaims{}
claimMaps[jwtclaims.UserIDClaim] = user.Id claimMaps[string(jwtclaims.UserIDClaim)] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+string(jwtclaims.AccountIDSuffix)] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+string(jwtclaims.DomainIDSuffix)] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory claimMaps[m.audience+string(jwtclaims.DomainCategorySuffix)] = account.DomainCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken))
// Update the current request with the new context information. // Update the current request with the new context information.

View File

@ -1 +0,0 @@
package middleware

View File

@ -6,12 +6,14 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
) )
type key string
const ( const (
TokenUserProperty = "user" TokenUserProperty key = "user"
AccountIDSuffix = "wt_account_id" AccountIDSuffix key = "wt_account_id"
DomainIDSuffix = "wt_account_domain" DomainIDSuffix key = "wt_account_domain"
DomainCategorySuffix = "wt_account_domain_category" DomainCategorySuffix key = "wt_account_domain_category"
UserIDClaim = "sub" UserIDClaim key = "sub"
) )
// Extract function type // Extract function type
@ -60,7 +62,7 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
ce.FromRequestContext = ce.fromRequestContext ce.FromRequestContext = ce.fromRequestContext
} }
if ce.userIDClaim == "" { if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim ce.userIDClaim = string(UserIDClaim)
} }
return ce return ce
} }
@ -74,15 +76,15 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
return jwtClaims return jwtClaims
} }
jwtClaims.UserId = userID jwtClaims.UserId = userID
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] accountIDClaim, ok := claims[c.authAudience+string(AccountIDSuffix)]
if ok { if ok {
jwtClaims.AccountId = accountIDClaim.(string) jwtClaims.AccountId = accountIDClaim.(string)
} }
domainClaim, ok := claims[c.authAudience+DomainIDSuffix] domainClaim, ok := claims[c.authAudience+string(DomainIDSuffix)]
if ok { if ok {
jwtClaims.Domain = domainClaim.(string) jwtClaims.Domain = domainClaim.(string)
} }
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] domainCategoryClaim, ok := claims[c.authAudience+string(DomainCategorySuffix)]
if ok { if ok {
jwtClaims.DomainCategory = domainCategoryClaim.(string) jwtClaims.DomainCategory = domainCategoryClaim.(string)
} }

View File

@ -12,21 +12,21 @@ import (
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
claimMaps := jwt.MapClaims{} claimMaps := jwt.MapClaims{}
if claims.UserId != "" { if claims.UserId != "" {
claimMaps[UserIDClaim] = claims.UserId claimMaps[string(UserIDClaim)] = claims.UserId
} }
if claims.AccountId != "" { if claims.AccountId != "" {
claimMaps[audiance+AccountIDSuffix] = claims.AccountId claimMaps[audiance+string(AccountIDSuffix)] = claims.AccountId
} }
if claims.Domain != "" { if claims.Domain != "" {
claimMaps[audiance+DomainIDSuffix] = claims.Domain claimMaps[audiance+string(DomainIDSuffix)] = claims.Domain
} }
if claims.DomainCategory != "" { if claims.DomainCategory != "" {
claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory claimMaps[audiance+string(DomainCategorySuffix)] = claims.DomainCategory
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed") require.NoError(t, err, "creating testing request failed")
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
return testRequest return testRequest
} }
@ -124,7 +124,7 @@ func TestExtractClaimsSetOptions(t *testing.T) {
t.Error("audience should be empty") t.Error("audience should be empty")
return return
} }
if c.extractor.userIDClaim != UserIDClaim { if c.extractor.userIDClaim != string(UserIDClaim) {
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
return return
} }