mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-23 14:28:51 +01:00
183 lines
5.1 KiB
Go
183 lines
5.1 KiB
Go
package jwtclaims
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
|
|
claimMaps := jwt.MapClaims{}
|
|
if claims.UserId != "" {
|
|
claimMaps[UserIDClaim] = claims.UserId
|
|
}
|
|
if claims.AccountId != "" {
|
|
claimMaps[audiance+AccountIDSuffix] = claims.AccountId
|
|
}
|
|
if claims.Domain != "" {
|
|
claimMaps[audiance+DomainIDSuffix] = claims.Domain
|
|
}
|
|
if claims.DomainCategory != "" {
|
|
claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
|
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
|
require.NoError(t, err, "creating testing request failed")
|
|
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
|
|
|
|
return testRequest
|
|
}
|
|
|
|
func TestExtractClaimsFromRequestContext(t *testing.T) {
|
|
type test struct {
|
|
name string
|
|
inputAuthorizationClaims AuthorizationClaims
|
|
inputAudiance string
|
|
testingFunc require.ComparisonAssertionFunc
|
|
expectedMSG string
|
|
}
|
|
|
|
testCase1 := test{
|
|
name: "All Claim Fields",
|
|
inputAudiance: "https://login/",
|
|
inputAuthorizationClaims: AuthorizationClaims{
|
|
UserId: "test",
|
|
Domain: "test.com",
|
|
AccountId: "testAcc",
|
|
DomainCategory: "public",
|
|
},
|
|
testingFunc: require.EqualValues,
|
|
expectedMSG: "extracted claims should match input claims",
|
|
}
|
|
|
|
testCase2 := test{
|
|
name: "Domain Is Empty",
|
|
inputAudiance: "https://login/",
|
|
inputAuthorizationClaims: AuthorizationClaims{
|
|
UserId: "test",
|
|
AccountId: "testAcc",
|
|
},
|
|
testingFunc: require.EqualValues,
|
|
expectedMSG: "extracted claims should match input claims",
|
|
}
|
|
|
|
testCase3 := test{
|
|
name: "Account ID Is Empty",
|
|
inputAudiance: "https://login/",
|
|
inputAuthorizationClaims: AuthorizationClaims{
|
|
UserId: "test",
|
|
Domain: "test.com",
|
|
},
|
|
testingFunc: require.EqualValues,
|
|
expectedMSG: "extracted claims should match input claims",
|
|
}
|
|
|
|
testCase4 := test{
|
|
name: "Category Is Empty",
|
|
inputAudiance: "https://login/",
|
|
inputAuthorizationClaims: AuthorizationClaims{
|
|
UserId: "test",
|
|
Domain: "test.com",
|
|
AccountId: "testAcc",
|
|
},
|
|
testingFunc: require.EqualValues,
|
|
expectedMSG: "extracted claims should match input claims",
|
|
}
|
|
|
|
testCase5 := test{
|
|
name: "Only User ID Is set",
|
|
inputAudiance: "https://login/",
|
|
inputAuthorizationClaims: AuthorizationClaims{
|
|
UserId: "test",
|
|
},
|
|
testingFunc: require.EqualValues,
|
|
expectedMSG: "extracted claims should match input claims",
|
|
}
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
|
|
|
|
extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
|
|
extractedClaims := extractor.FromRequestContext(request)
|
|
|
|
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractClaimsSetOptions(t *testing.T) {
|
|
type test struct {
|
|
name string
|
|
extractor *ClaimsExtractor
|
|
check func(t *testing.T, c test)
|
|
}
|
|
|
|
testCase1 := test{
|
|
name: "No custom options",
|
|
extractor: NewClaimsExtractor(),
|
|
check: func(t *testing.T, c test) {
|
|
if c.extractor.authAudience != "" {
|
|
t.Error("audience should be empty")
|
|
return
|
|
}
|
|
if c.extractor.userIDClaim != UserIDClaim {
|
|
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
|
|
return
|
|
}
|
|
if c.extractor.FromRequestContext == nil {
|
|
t.Error("from request context should not be nil")
|
|
return
|
|
}
|
|
},
|
|
}
|
|
|
|
testCase2 := test{
|
|
name: "Custom audience",
|
|
extractor: NewClaimsExtractor(WithAudience("https://login/")),
|
|
check: func(t *testing.T, c test) {
|
|
if c.extractor.authAudience != "https://login/" {
|
|
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
|
|
return
|
|
}
|
|
},
|
|
}
|
|
|
|
testCase3 := test{
|
|
name: "Custom user id claim",
|
|
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
|
|
check: func(t *testing.T, c test) {
|
|
if c.extractor.userIDClaim != "customUserId" {
|
|
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
|
|
return
|
|
}
|
|
},
|
|
}
|
|
|
|
testCase4 := test{
|
|
name: "Custom extractor from request context",
|
|
extractor: NewClaimsExtractor(
|
|
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
|
|
return AuthorizationClaims{
|
|
UserId: "testCustomRequest",
|
|
}
|
|
})),
|
|
check: func(t *testing.T, c test) {
|
|
claims := c.extractor.FromRequestContext(&http.Request{})
|
|
if claims.UserId != "testCustomRequest" {
|
|
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
|
|
return
|
|
}
|
|
},
|
|
}
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
testCase.check(t, testCase)
|
|
})
|
|
}
|
|
}
|