Jwtclaims package (#242)

* Move JWTClaims logic to its own package

* Add extractor tests
This commit is contained in:
Maycon Santos 2022-02-23 20:02:02 +01:00 committed by GitHub
parent 5f5cbf7e20
commit b29948b910
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 164 additions and 53 deletions

View File

@ -3,6 +3,7 @@ package handler
import (
"encoding/json"
"fmt"
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
"net/http"
"time"
@ -15,7 +16,7 @@ import (
type Peers struct {
accountManager server.AccountManager
authAudience string
jwtExtractor JWTClaimsExtractor
jwtExtractor jwtclaims.ClaimsExtractor
}
//PeerResponse is a response sent to the client
@ -37,7 +38,7 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
return &Peers{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *NewJWTClaimsExtractor(nil),
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
}
}
@ -69,7 +70,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
}
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
jwtClaims := h.jwtExtractor.extractClaimsFromRequestContext(r, h.authAudience)
jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
if err != nil {

View File

@ -2,6 +2,7 @@ package handler
import (
"encoding/json"
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
"io"
"net"
"net/http"
@ -27,9 +28,9 @@ func initTestMetaData(peer ...*server.Peer) *Peers {
},
},
authAudience: "",
jwtExtractor: JWTClaimsExtractor{
extractClaimsFromRequestContext: func(r *http.Request, authAudiance string) JWTClaims {
return JWTClaims{
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@ -3,6 +3,7 @@ package handler
import (
"encoding/json"
"fmt"
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
"net/http"
"time"
@ -122,8 +123,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
}
func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
extractor := NewJWTClaimsExtractor(nil)
jwtClaims := extractor.extractClaimsFromRequestContext(r, h.authAudience)
extractor := jwtclaims.NewClaimsExtractor(nil)
jwtClaims := extractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
if err != nil {

View File

@ -5,53 +5,8 @@ import (
"errors"
"net/http"
"time"
"github.com/golang-jwt/jwt"
)
// JWTClaims stores information from JWTs
type JWTClaims struct {
UserId string
AccountId string
Domain string
}
type extractJWTClaims func(r *http.Request, authAudiance string) JWTClaims
type JWTClaimsExtractor struct {
extractClaimsFromRequestContext extractJWTClaims
}
// NewJWTClaimsExtractor returns an extractor, and if provided with a function with extractJWTClaims signature,
// then it will use that logic. Uses extractClaimsFromRequestContext by default
func NewJWTClaimsExtractor(e extractJWTClaims) *JWTClaimsExtractor {
var extractFunc extractJWTClaims
if extractFunc = e; extractFunc == nil {
extractFunc = extractClaimsFromRequestContext
}
return &JWTClaimsExtractor{
extractClaimsFromRequestContext: extractFunc,
}
}
// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims {
token := r.Context().Value("user").(*jwt.Token)
claims := token.Claims.(jwt.MapClaims)
jwtClaims := JWTClaims{}
jwtClaims.UserId = claims["sub"].(string)
accountIdClaim, ok := claims[authAudiance+"wt_account_id"]
if ok {
jwtClaims.AccountId = accountIdClaim.(string)
}
domainClaim, ok := claims[authAudiance+"wt_user_domain"]
if ok {
jwtClaims.Domain = domainClaim.(string)
}
return jwtClaims
}
//writeJSONObject simply writes object to the HTTP reponse in JSON format
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)

View File

@ -0,0 +1,8 @@
package jwtclaims
// AuthorizationClaims stores authorization information from JWTs
type AuthorizationClaims struct {
UserId string
AccountId string
Domain string
}

View File

@ -0,0 +1,51 @@
package jwtclaims
import (
"github.com/golang-jwt/jwt"
"net/http"
)
const (
TokenUserProperty = "user"
AccountIDSuffix = "wt_account_id"
DomainIDSuffix = "wt_account_domain"
UserIDClaim = "sub"
)
// Extract function type
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
ExtractClaimsFromRequestContext ExtractClaims
}
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
var extractFunc ExtractClaims
if extractFunc = e; extractFunc == nil {
extractFunc = ExtractClaimsFromRequestContext
}
return &ClaimsExtractor{
ExtractClaimsFromRequestContext: extractFunc,
}
}
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func ExtractClaimsFromRequestContext(r *http.Request, authAudiance string) AuthorizationClaims {
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{}
jwtClaims.UserId = claims[UserIDClaim].(string)
accountIdClaim, ok := claims[authAudiance+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIdClaim.(string)
}
domainClaim, ok := claims[authAudiance+DomainIDSuffix]
if ok {
jwtClaims.Domain = domainClaim.(string)
}
return jwtClaims
}

View File

@ -0,0 +1,94 @@
package jwtclaims
import (
"context"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
"net/http"
"testing"
)
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
claimMaps := jwt.MapClaims{}
if claims.UserId != "" {
claimMaps[UserIDClaim] = claims.UserId
}
if claims.AccountId != "" {
claimMaps[audiance+AccountIDSuffix] = claims.AccountId
}
if claims.Domain != "" {
claimMaps[audiance+DomainIDSuffix] = claims.Domain
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed")
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
return testRequest
}
func TestExtractClaimsFromRequestContext(t *testing.T) {
type test struct {
name string
inputAuthorizationClaims AuthorizationClaims
inputAudiance string
testingFunc require.ComparisonAssertionFunc
expectedMSG string
}
testCase1 := test{
name: "All Claim Fields",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
AccountId: "testAcc",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase2 := test{
name: "Domain Is Empty",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
AccountId: "testAcc",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase3 := test{
name: "Account ID Is Empty",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase4 := test{
name: "Only User ID Is set",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
})
}
}