netbird/management/server/jwtclaims/extractor_test.go
Givi Khojanashvili 3ec8274b8e
Feature: add custom id claim (#667)
This feature allows using the custom claim in the JWT token as a user ID.

Refactor claims extractor with options support

Add is_current to the user API response
2023-02-03 21:47:20 +01:00

183 lines
5.1 KiB
Go

package jwtclaims
import (
"context"
"net/http"
"testing"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
)
// newTestRequestWithJWT builds a GET request to http://localhost whose context
// carries a jwt.Token (created with the HS256 signing method) under the
// TokenUserProperty key.
//
// Only the non-empty fields of claims are copied into the token: the user ID
// is stored under UserIDClaim, while the account ID, domain and domain
// category are stored under audiance concatenated with AccountIDSuffix,
// DomainIDSuffix and DomainCategorySuffix respectively.
//
// The calling test is failed immediately if the request cannot be created.
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
	// Mark this function as a test helper so that file/line information
	// printed on failure skips it and points at the calling test.
	t.Helper()

	claimMaps := jwt.MapClaims{}
	if claims.UserId != "" {
		claimMaps[UserIDClaim] = claims.UserId
	}
	if claims.AccountId != "" {
		claimMaps[audiance+AccountIDSuffix] = claims.AccountId
	}
	if claims.Domain != "" {
		claimMaps[audiance+DomainIDSuffix] = claims.Domain
	}
	if claims.DomainCategory != "" {
		claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)

	r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
	require.NoError(t, err, "creating testing request failed")

	// Attach the token to the request context under TokenUserProperty.
	testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
	return testRequest
}
// TestExtractClaimsFromRequestContext builds a request with
// newTestRequestWithJWT for several combinations of populated claim fields and
// asserts that an extractor created with the same audience returns claims
// equal to the ones that were put in.
func TestExtractClaimsFromRequestContext(t *testing.T) {
	const (
		loginAudience = "https://login/"
		matchMsg      = "extracted claims should match input claims"
	)

	type scenario struct {
		name     string
		audience string
		claims   AuthorizationClaims
		compare  require.ComparisonAssertionFunc
		failMsg  string
	}

	scenarios := []scenario{
		{
			name:     "All Claim Fields",
			audience: loginAudience,
			claims: AuthorizationClaims{
				UserId:         "test",
				Domain:         "test.com",
				AccountId:      "testAcc",
				DomainCategory: "public",
			},
			compare: require.EqualValues,
			failMsg: matchMsg,
		},
		{
			name:     "Domain Is Empty",
			audience: loginAudience,
			claims: AuthorizationClaims{
				UserId:    "test",
				AccountId: "testAcc",
			},
			compare: require.EqualValues,
			failMsg: matchMsg,
		},
		{
			name:     "Account ID Is Empty",
			audience: loginAudience,
			claims: AuthorizationClaims{
				UserId: "test",
				Domain: "test.com",
			},
			compare: require.EqualValues,
			failMsg: matchMsg,
		},
		{
			name:     "Category Is Empty",
			audience: loginAudience,
			claims: AuthorizationClaims{
				UserId:    "test",
				Domain:    "test.com",
				AccountId: "testAcc",
			},
			compare: require.EqualValues,
			failMsg: matchMsg,
		},
		{
			name:     "Only User ID Is set",
			audience: loginAudience,
			claims: AuthorizationClaims{
				UserId: "test",
			},
			compare: require.EqualValues,
			failMsg: matchMsg,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			req := newTestRequestWithJWT(t, sc.claims, sc.audience)
			got := NewClaimsExtractor(WithAudience(sc.audience)).FromRequestContext(req)
			// The input claims are the expected value; got is what came back out.
			sc.compare(t, sc.claims, got, sc.failMsg)
		})
	}
}
func TestExtractClaimsSetOptions(t *testing.T) {
type test struct {
name string
extractor *ClaimsExtractor
check func(t *testing.T, c test)
}
testCase1 := test{
name: "No custom options",
extractor: NewClaimsExtractor(),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "" {
t.Error("audience should be empty")
return
}
if c.extractor.userIDClaim != UserIDClaim {
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
return
}
if c.extractor.FromRequestContext == nil {
t.Error("from request context should not be nil")
return
}
},
}
testCase2 := test{
name: "Custom audience",
extractor: NewClaimsExtractor(WithAudience("https://login/")),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "https://login/" {
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
return
}
},
}
testCase3 := test{
name: "Custom user id claim",
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
check: func(t *testing.T, c test) {
if c.extractor.userIDClaim != "customUserId" {
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
return
}
},
}
testCase4 := test{
name: "Custom extractor from request context",
extractor: NewClaimsExtractor(
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
return AuthorizationClaims{
UserId: "testCustomRequest",
}
})),
check: func(t *testing.T, c test) {
claims := c.extractor.FromRequestContext(&http.Request{})
if claims.UserId != "testCustomRequest" {
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
return
}
},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
testCase.check(t, testCase)
})
}
}