mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-30 14:25:06 +02:00
Jwtclaims package (#242)
* Move JWTClaims logic to its own package * Add extractor tests
This commit is contained in:
parent
5f5cbf7e20
commit
b29948b910
@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -15,7 +16,7 @@ import (
|
|||||||
type Peers struct {
|
type Peers struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
authAudience string
|
authAudience string
|
||||||
jwtExtractor JWTClaimsExtractor
|
jwtExtractor jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
//PeerResponse is a response sent to the client
|
//PeerResponse is a response sent to the client
|
||||||
@ -37,7 +38,7 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
|
|||||||
return &Peers{
|
return &Peers{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
authAudience: authAudience,
|
authAudience: authAudience,
|
||||||
jwtExtractor: *NewJWTClaimsExtractor(nil),
|
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +70,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
|
func (h *Peers) getPeerAccount(r *http.Request) (*server.Account, error) {
|
||||||
jwtClaims := h.jwtExtractor.extractClaimsFromRequestContext(r, h.authAudience)
|
jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
|
||||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -27,9 +28,9 @@ func initTestMetaData(peer ...*server.Peer) *Peers {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
authAudience: "",
|
authAudience: "",
|
||||||
jwtExtractor: JWTClaimsExtractor{
|
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||||
extractClaimsFromRequestContext: func(r *http.Request, authAudiance string) JWTClaims {
|
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||||
return JWTClaims{
|
return jwtclaims.AuthorizationClaims{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
|
@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/wiretrustee/wiretrustee/management/server/jwtclaims"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -122,8 +123,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
|
func (h *SetupKeys) getSetupKeyAccount(r *http.Request) (*server.Account, error) {
|
||||||
extractor := NewJWTClaimsExtractor(nil)
|
extractor := jwtclaims.NewClaimsExtractor(nil)
|
||||||
jwtClaims := extractor.extractClaimsFromRequestContext(r, h.authAudience)
|
jwtClaims := extractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
|
||||||
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
account, err := h.accountManager.GetAccountByUserOrAccountId(jwtClaims.UserId, jwtClaims.AccountId, jwtClaims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -5,53 +5,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// JWTClaims stores information from JWTs
|
|
||||||
type JWTClaims struct {
|
|
||||||
UserId string
|
|
||||||
AccountId string
|
|
||||||
Domain string
|
|
||||||
}
|
|
||||||
|
|
||||||
type extractJWTClaims func(r *http.Request, authAudiance string) JWTClaims
|
|
||||||
|
|
||||||
type JWTClaimsExtractor struct {
|
|
||||||
extractClaimsFromRequestContext extractJWTClaims
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewJWTClaimsExtractor returns an extractor, and if provided with a function with extractJWTClaims signature,
|
|
||||||
// then it will use that logic. Uses extractClaimsFromRequestContext by default
|
|
||||||
func NewJWTClaimsExtractor(e extractJWTClaims) *JWTClaimsExtractor {
|
|
||||||
var extractFunc extractJWTClaims
|
|
||||||
if extractFunc = e; extractFunc == nil {
|
|
||||||
extractFunc = extractClaimsFromRequestContext
|
|
||||||
}
|
|
||||||
|
|
||||||
return &JWTClaimsExtractor{
|
|
||||||
extractClaimsFromRequestContext: extractFunc,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
|
||||||
func extractClaimsFromRequestContext(r *http.Request, authAudiance string) JWTClaims {
|
|
||||||
token := r.Context().Value("user").(*jwt.Token)
|
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
|
||||||
jwtClaims := JWTClaims{}
|
|
||||||
jwtClaims.UserId = claims["sub"].(string)
|
|
||||||
accountIdClaim, ok := claims[authAudiance+"wt_account_id"]
|
|
||||||
if ok {
|
|
||||||
jwtClaims.AccountId = accountIdClaim.(string)
|
|
||||||
}
|
|
||||||
domainClaim, ok := claims[authAudiance+"wt_user_domain"]
|
|
||||||
if ok {
|
|
||||||
jwtClaims.Domain = domainClaim.(string)
|
|
||||||
}
|
|
||||||
return jwtClaims
|
|
||||||
}
|
|
||||||
|
|
||||||
//writeJSONObject simply writes object to the HTTP reponse in JSON format
|
//writeJSONObject simply writes object to the HTTP reponse in JSON format
|
||||||
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
|
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
8
management/server/jwtclaims/claims.go
Normal file
8
management/server/jwtclaims/claims.go
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
package jwtclaims
|
||||||
|
|
||||||
|
// AuthorizationClaims stores authorization information from JWTs
|
||||||
|
type AuthorizationClaims struct {
|
||||||
|
UserId string
|
||||||
|
AccountId string
|
||||||
|
Domain string
|
||||||
|
}
|
51
management/server/jwtclaims/extractor.go
Normal file
51
management/server/jwtclaims/extractor.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package jwtclaims
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenUserProperty = "user"
|
||||||
|
AccountIDSuffix = "wt_account_id"
|
||||||
|
DomainIDSuffix = "wt_account_domain"
|
||||||
|
UserIDClaim = "sub"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extract function type
|
||||||
|
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
|
||||||
|
|
||||||
|
// ClaimsExtractor struct that holds the extract function
|
||||||
|
type ClaimsExtractor struct {
|
||||||
|
ExtractClaimsFromRequestContext ExtractClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
|
||||||
|
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
|
||||||
|
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
|
||||||
|
var extractFunc ExtractClaims
|
||||||
|
if extractFunc = e; extractFunc == nil {
|
||||||
|
extractFunc = ExtractClaimsFromRequestContext
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ClaimsExtractor{
|
||||||
|
ExtractClaimsFromRequestContext: extractFunc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||||
|
func ExtractClaimsFromRequestContext(r *http.Request, authAudiance string) AuthorizationClaims {
|
||||||
|
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
||||||
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
jwtClaims := AuthorizationClaims{}
|
||||||
|
jwtClaims.UserId = claims[UserIDClaim].(string)
|
||||||
|
accountIdClaim, ok := claims[authAudiance+AccountIDSuffix]
|
||||||
|
if ok {
|
||||||
|
jwtClaims.AccountId = accountIdClaim.(string)
|
||||||
|
}
|
||||||
|
domainClaim, ok := claims[authAudiance+DomainIDSuffix]
|
||||||
|
if ok {
|
||||||
|
jwtClaims.Domain = domainClaim.(string)
|
||||||
|
}
|
||||||
|
return jwtClaims
|
||||||
|
}
|
94
management/server/jwtclaims/extractor_test.go
Normal file
94
management/server/jwtclaims/extractor_test.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package jwtclaims
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
|
||||||
|
claimMaps := jwt.MapClaims{}
|
||||||
|
if claims.UserId != "" {
|
||||||
|
claimMaps[UserIDClaim] = claims.UserId
|
||||||
|
}
|
||||||
|
if claims.AccountId != "" {
|
||||||
|
claimMaps[audiance+AccountIDSuffix] = claims.AccountId
|
||||||
|
}
|
||||||
|
if claims.Domain != "" {
|
||||||
|
claimMaps[audiance+DomainIDSuffix] = claims.Domain
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||||
|
require.NoError(t, err, "creating testing request failed")
|
||||||
|
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
|
||||||
|
|
||||||
|
return testRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
inputAuthorizationClaims AuthorizationClaims
|
||||||
|
inputAudiance string
|
||||||
|
testingFunc require.ComparisonAssertionFunc
|
||||||
|
expectedMSG string
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase1 := test{
|
||||||
|
name: "All Claim Fields",
|
||||||
|
inputAudiance: "https://login/",
|
||||||
|
inputAuthorizationClaims: AuthorizationClaims{
|
||||||
|
UserId: "test",
|
||||||
|
Domain: "test.com",
|
||||||
|
AccountId: "testAcc",
|
||||||
|
},
|
||||||
|
testingFunc: require.EqualValues,
|
||||||
|
expectedMSG: "extracted claims should match input claims",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase2 := test{
|
||||||
|
name: "Domain Is Empty",
|
||||||
|
inputAudiance: "https://login/",
|
||||||
|
inputAuthorizationClaims: AuthorizationClaims{
|
||||||
|
UserId: "test",
|
||||||
|
AccountId: "testAcc",
|
||||||
|
},
|
||||||
|
testingFunc: require.EqualValues,
|
||||||
|
expectedMSG: "extracted claims should match input claims",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase3 := test{
|
||||||
|
name: "Account ID Is Empty",
|
||||||
|
inputAudiance: "https://login/",
|
||||||
|
inputAuthorizationClaims: AuthorizationClaims{
|
||||||
|
UserId: "test",
|
||||||
|
Domain: "test.com",
|
||||||
|
},
|
||||||
|
testingFunc: require.EqualValues,
|
||||||
|
expectedMSG: "extracted claims should match input claims",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCase4 := test{
|
||||||
|
name: "Only User ID Is set",
|
||||||
|
inputAudiance: "https://login/",
|
||||||
|
inputAuthorizationClaims: AuthorizationClaims{
|
||||||
|
UserId: "test",
|
||||||
|
},
|
||||||
|
testingFunc: require.EqualValues,
|
||||||
|
expectedMSG: "extracted claims should match input claims",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
|
||||||
|
|
||||||
|
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
|
||||||
|
|
||||||
|
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user