diff --git a/management/server/http/handler/peers.go b/management/server/http/handler/peers.go index bdedd723d..9212a4677 100644 --- a/management/server/http/handler/peers.go +++ b/management/server/http/handler/peers.go @@ -3,6 +3,7 @@ package handler import ( "encoding/json" "fmt" + "github.com/wiretrustee/wiretrustee/management/server/jwtclaims" "net/http" "time" @@ -15,7 +16,7 @@ import ( type Peers struct { accountManager server.AccountManager authAudience string - jwtExtractor JWTClaimsExtractor + jwtExtractor jwtclaims.ClaimsExtractor } //PeerResponse is a response sent to the client @@ -37,7 +38,7 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers return &Peers{ accountManager: accountManager, authAudience: authAudience, - jwtExtractor: *NewJWTClaimsExtractor(nil), + jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), } } @@ -69,7 +70,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW } func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) { - jwtClaims := h.jwtExtractor.extractClaimsFromRequestContext(r, h.authAudience) + jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) if err != nil { diff --git a/management/server/http/handler/peers_test.go b/management/server/http/handler/peers_test.go index 8a67a38d8..bb4aa4a03 100644 --- a/management/server/http/handler/peers_test.go +++ b/management/server/http/handler/peers_test.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "github.com/wiretrustee/wiretrustee/management/server/jwtclaims" "io" "net" "net/http" @@ -27,9 +28,9 @@ func initTestMetaData(peer ...*server.Peer) *Peers { }, }, authAudience: "", - jwtExtractor: JWTClaimsExtractor{ - extractClaimsFromRequestContext: func(r *http.Request, authAudiance string) JWTClaims { - return JWTClaims{ + jwtExtractor: jwtclaims.ClaimsExtractor{ + ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handler/setupkeys.go b/management/server/http/handler/setupkeys.go index 48e24e072..ee3dcd628 100644 --- a/management/server/http/handler/setupkeys.go +++ b/management/server/http/handler/setupkeys.go @@ -3,6 +3,7 @@ package handler import ( "encoding/json" "fmt" + "github.com/wiretrustee/wiretrustee/management/server/jwtclaims" "net/http" "time" @@ -122,8 +123,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R } func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) { - extractor := NewJWTClaimsExtractor(nil) - jwtClaims := extractor.extractClaimsFromRequestContext(r, h.authAudience) + extractor := jwtclaims.NewClaimsExtractor(nil) + jwtClaims := extractor.ExtractClaimsFromRequestContext(r, h.authAudience) account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain) if err != nil { diff --git a/management/server/http/handler/util.go b/management/server/http/handler/util.go index 1c7b63b53..08b1c7d3e 100644 --- a/management/server/http/handler/util.go +++ b/management/server/http/handler/util.go @@ -5,53 +5,8 @@ import ( "errors" "net/http" "time" - - "github.com/golang-jwt/jwt" ) -// JWTClaims stores information from JWTs -type JWTClaims struct { - UserId string - AccountId string - Domain string -} - -type extractJWTClaims func(r *http.Request, authAudiance string) JWTClaims - -type JWTClaimsExtractor struct { - extractClaimsFromRequestContext extractJWTClaims -} - -// NewJWTClaimsExtractor returns an extractor, and if provided with a function with extractJWTClaims signature, -// then it will use that logic. Uses extractClaimsFromRequestContext by default -func NewJWTClaimsExtractor(e extractJWTClaims) *JWTClaimsExtractor { - var extractFunc extractJWTClaims - if extractFunc = e; extractFunc == nil { - extractFunc = extractClaimsFromRequestContext - } - - return &JWTClaimsExtractor{ - extractClaimsFromRequestContext: extractFunc, - } -} - -// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) -func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims { - token := r.Context().Value("user").(*jwt.Token) - claims := token.Claims.(jwt.MapClaims) - jwtClaims := JWTClaims{} - jwtClaims.UserId = claims["sub"].(string) - accountIdClaim, ok := claims[authAudiance+"wt_account_id"] - if ok { - jwtClaims.AccountId = accountIdClaim.(string) - } - domainClaim, ok := claims[authAudiance+"wt_user_domain"] - if ok { - jwtClaims.Domain = domainClaim.(string) - } - return jwtClaims -} - //writeJSONObject simply writes object to the HTTP reponse in JSON format func writeJSONObject(w http.ResponseWriter, obj interface{}) { w.WriteHeader(http.StatusOK) diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go new file mode 100644 index 000000000..277f3c20d --- /dev/null +++ b/management/server/jwtclaims/claims.go @@ -0,0 +1,8 @@ +package jwtclaims + +// AuthorizationClaims stores authorization information from JWTs +type AuthorizationClaims struct { + UserId string + AccountId string + Domain string +} diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go new file mode 100644 index 000000000..f6f609d12 --- /dev/null +++ b/management/server/jwtclaims/extractor.go @@ -0,0 +1,51 @@ +package jwtclaims + +import ( + "github.com/golang-jwt/jwt" + "net/http" +) + +const ( + TokenUserProperty = "user" + AccountIDSuffix = "wt_account_id" + DomainIDSuffix = "wt_account_domain" + UserIDClaim = "sub" +) + +// Extract function type +type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims + +// ClaimsExtractor struct that holds the extract function +type ClaimsExtractor struct { + ExtractClaimsFromRequestContext ExtractClaims +} + +// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature, +// then it will use that logic. Uses ExtractClaimsFromRequestContext by default +func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor { + var extractFunc ExtractClaims + if extractFunc = e; extractFunc == nil { + extractFunc = ExtractClaimsFromRequestContext + } + + return &ClaimsExtractor{ + ExtractClaimsFromRequestContext: extractFunc, + } +} + +// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth) +func ExtractClaimsFromRequestContext(r *http.Request, authAudiance string) AuthorizationClaims { + token := r.Context().Value(TokenUserProperty).(*jwt.Token) + claims := token.Claims.(jwt.MapClaims) + jwtClaims := AuthorizationClaims{} + jwtClaims.UserId = claims[UserIDClaim].(string) + accountIdClaim, ok := claims[authAudiance+AccountIDSuffix] + if ok { + jwtClaims.AccountId = accountIdClaim.(string) + } + domainClaim, ok := claims[authAudiance+DomainIDSuffix] + if ok { + jwtClaims.Domain = domainClaim.(string) + } + return jwtClaims +} diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go new file mode 100644 index 000000000..7859d187a --- /dev/null +++ b/management/server/jwtclaims/extractor_test.go @@ -0,0 +1,94 @@ +package jwtclaims + +import ( + "context" + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/require" + "net/http" + "testing" +) + +func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { + claimMaps := jwt.MapClaims{} + if claims.UserId != "" { + claimMaps[UserIDClaim] = claims.UserId + } + if claims.AccountId != "" { + claimMaps[audiance+AccountIDSuffix] = claims.AccountId + } + if claims.Domain != "" { + claimMaps[audiance+DomainIDSuffix] = claims.Domain + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) + require.NoError(t, err, "creating testing request failed") + testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint + + return testRequest +} + +func TestExtractClaimsFromRequestContext(t *testing.T) { + + type test struct { + name string + inputAuthorizationClaims AuthorizationClaims + inputAudiance string + testingFunc require.ComparisonAssertionFunc + expectedMSG string + } + + testCase1 := test{ + name: "All Claim Fields", + inputAudiance: "https://login/", + inputAuthorizationClaims: AuthorizationClaims{ + UserId: "test", + Domain: "test.com", + AccountId: "testAcc", + }, + testingFunc: require.EqualValues, + expectedMSG: "extracted claims should match input claims", + } + + testCase2 := test{ + name: "Domain Is Empty", + inputAudiance: "https://login/", + inputAuthorizationClaims: AuthorizationClaims{ + UserId: "test", + AccountId: "testAcc", + }, + testingFunc: require.EqualValues, + expectedMSG: "extracted claims should match input claims", + } + + testCase3 := test{ + name: "Account ID Is Empty", + inputAudiance: "https://login/", + inputAuthorizationClaims: AuthorizationClaims{ + UserId: "test", + Domain: "test.com", + }, + testingFunc: require.EqualValues, + expectedMSG: "extracted claims should match input claims", + } + + testCase4 := test{ + name: "Only User ID Is set", + inputAudiance: "https://login/", + inputAuthorizationClaims: AuthorizationClaims{ + UserId: "test", + }, + testingFunc: require.EqualValues, + expectedMSG: "extracted claims should match input claims", + } + + for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { + t.Run(testCase.name, func(t *testing.T) { + + request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance) + + extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance) + + testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG) + }) + } +}