mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 18:21:24 +01:00
add test + codacy
This commit is contained in:
parent
2a79995706
commit
1343a3f00e
@ -495,6 +495,44 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
|||||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
account := newAccountWithId("account_id", "testuser", "")
|
||||||
|
|
||||||
|
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||||
|
hashedToken := sha256.Sum256([]byte(token))
|
||||||
|
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||||
|
account.Users["someUser"] = &User{
|
||||||
|
Id: "someUser",
|
||||||
|
PATs: map[string]*PersonalAccessToken{
|
||||||
|
"tokenId": {
|
||||||
|
ID: "tokenId",
|
||||||
|
HashedToken: encodedHashedToken,
|
||||||
|
LastUsed: time.Time{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error when saving account: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
am := DefaultAccountManager{
|
||||||
|
Store: store,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.MarkPATUsed("tokenId")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error when marking PAT used: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err = am.Store.GetAccount("account_id")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error when getting account: %s", err)
|
||||||
|
}
|
||||||
|
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -124,10 +124,10 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
claimMaps := jwt.MapClaims{}
|
claimMaps := jwt.MapClaims{}
|
||||||
claimMaps[jwtclaims.UserIDClaim] = user.Id
|
claimMaps[string(jwtclaims.UserIDClaim)] = user.Id
|
||||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
claimMaps[m.audience+string(jwtclaims.AccountIDSuffix)] = account.Id
|
||||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
claimMaps[m.audience+string(jwtclaims.DomainIDSuffix)] = account.Domain
|
||||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
claimMaps[m.audience+string(jwtclaims.DomainCategorySuffix)] = account.DomainCategory
|
||||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken))
|
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken))
|
||||||
// Update the current request with the new context information.
|
// Update the current request with the new context information.
|
||||||
|
@ -1 +0,0 @@
|
|||||||
package middleware
|
|
@ -6,12 +6,14 @@ import (
|
|||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type key string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TokenUserProperty = "user"
|
TokenUserProperty key = "user"
|
||||||
AccountIDSuffix = "wt_account_id"
|
AccountIDSuffix key = "wt_account_id"
|
||||||
DomainIDSuffix = "wt_account_domain"
|
DomainIDSuffix key = "wt_account_domain"
|
||||||
DomainCategorySuffix = "wt_account_domain_category"
|
DomainCategorySuffix key = "wt_account_domain_category"
|
||||||
UserIDClaim = "sub"
|
UserIDClaim key = "sub"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Extract function type
|
// Extract function type
|
||||||
@ -60,7 +62,7 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
|||||||
ce.FromRequestContext = ce.fromRequestContext
|
ce.FromRequestContext = ce.fromRequestContext
|
||||||
}
|
}
|
||||||
if ce.userIDClaim == "" {
|
if ce.userIDClaim == "" {
|
||||||
ce.userIDClaim = UserIDClaim
|
ce.userIDClaim = string(UserIDClaim)
|
||||||
}
|
}
|
||||||
return ce
|
return ce
|
||||||
}
|
}
|
||||||
@ -74,15 +76,15 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
|||||||
return jwtClaims
|
return jwtClaims
|
||||||
}
|
}
|
||||||
jwtClaims.UserId = userID
|
jwtClaims.UserId = userID
|
||||||
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
|
accountIDClaim, ok := claims[c.authAudience+string(AccountIDSuffix)]
|
||||||
if ok {
|
if ok {
|
||||||
jwtClaims.AccountId = accountIDClaim.(string)
|
jwtClaims.AccountId = accountIDClaim.(string)
|
||||||
}
|
}
|
||||||
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
|
domainClaim, ok := claims[c.authAudience+string(DomainIDSuffix)]
|
||||||
if ok {
|
if ok {
|
||||||
jwtClaims.Domain = domainClaim.(string)
|
jwtClaims.Domain = domainClaim.(string)
|
||||||
}
|
}
|
||||||
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
|
domainCategoryClaim, ok := claims[c.authAudience+string(DomainCategorySuffix)]
|
||||||
if ok {
|
if ok {
|
||||||
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
||||||
}
|
}
|
||||||
|
@ -12,21 +12,21 @@ import (
|
|||||||
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
|
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
|
||||||
claimMaps := jwt.MapClaims{}
|
claimMaps := jwt.MapClaims{}
|
||||||
if claims.UserId != "" {
|
if claims.UserId != "" {
|
||||||
claimMaps[UserIDClaim] = claims.UserId
|
claimMaps[string(UserIDClaim)] = claims.UserId
|
||||||
}
|
}
|
||||||
if claims.AccountId != "" {
|
if claims.AccountId != "" {
|
||||||
claimMaps[audiance+AccountIDSuffix] = claims.AccountId
|
claimMaps[audiance+string(AccountIDSuffix)] = claims.AccountId
|
||||||
}
|
}
|
||||||
if claims.Domain != "" {
|
if claims.Domain != "" {
|
||||||
claimMaps[audiance+DomainIDSuffix] = claims.Domain
|
claimMaps[audiance+string(DomainIDSuffix)] = claims.Domain
|
||||||
}
|
}
|
||||||
if claims.DomainCategory != "" {
|
if claims.DomainCategory != "" {
|
||||||
claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory
|
claimMaps[audiance+string(DomainCategorySuffix)] = claims.DomainCategory
|
||||||
}
|
}
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||||
require.NoError(t, err, "creating testing request failed")
|
require.NoError(t, err, "creating testing request failed")
|
||||||
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
|
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
|
||||||
|
|
||||||
return testRequest
|
return testRequest
|
||||||
}
|
}
|
||||||
@ -124,7 +124,7 @@ func TestExtractClaimsSetOptions(t *testing.T) {
|
|||||||
t.Error("audience should be empty")
|
t.Error("audience should be empty")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.extractor.userIDClaim != UserIDClaim {
|
if c.extractor.userIDClaim != string(UserIDClaim) {
|
||||||
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
|
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user