2022-02-23 20:02:02 +01:00
|
|
|
package jwtclaims
|
|
|
|
|
|
|
|
import (
|
|
|
|
"net/http"
|
2023-08-18 19:23:11 +02:00
|
|
|
"time"
|
2023-02-03 21:47:20 +01:00
|
|
|
|
|
|
|
"github.com/golang-jwt/jwt"
|
2022-02-23 20:02:02 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// Claim names and the context key shared by the JWT claim extraction code.
// The *Suffix constants are appended to the configured auth audience to form
// the full name of a custom claim.
const (
	// TokenUserProperty is the request-context key under which the validated
	// *jwt.Token is stored.
	TokenUserProperty = "user"

	// AccountIDSuffix is the name suffix of the account ID claim.
	AccountIDSuffix = "wt_account_id"
	// DomainIDSuffix is the name suffix of the account domain claim.
	DomainIDSuffix = "wt_account_domain"
	// DomainCategorySuffix is the name suffix of the account domain category claim.
	DomainCategorySuffix = "wt_account_domain_category"
	// LastLoginSuffix is the name suffix of the RFC 3339 last-login timestamp claim.
	LastLoginSuffix = "nb_last_login"

	// UserIDClaim is the default name of the claim that holds the user ID.
	UserIDClaim = "sub"
)
|
|
|
|
|
2023-03-30 18:59:35 +02:00
|
|
|
// ExtractClaims is the signature of a function that derives
// AuthorizationClaims from an incoming HTTP request.
type ExtractClaims func(req *http.Request) AuthorizationClaims
|
2022-02-23 20:02:02 +01:00
|
|
|
|
|
|
|
// ClaimsExtractor struct that holds the extract function
|
|
|
|
type ClaimsExtractor struct {
|
2023-02-03 21:47:20 +01:00
|
|
|
authAudience string
|
|
|
|
userIDClaim string
|
|
|
|
|
|
|
|
FromRequestContext ExtractClaims
|
2022-02-23 20:02:02 +01:00
|
|
|
}
|
|
|
|
|
2023-02-03 21:47:20 +01:00
|
|
|
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
|
|
|
|
type ClaimsExtractorOption func(*ClaimsExtractor)
|
|
|
|
|
|
|
|
// WithAudience sets the audience for the extractor
|
|
|
|
func WithAudience(audience string) ClaimsExtractorOption {
|
|
|
|
return func(c *ClaimsExtractor) {
|
|
|
|
c.authAudience = audience
|
2022-02-23 20:02:02 +01:00
|
|
|
}
|
2023-02-03 21:47:20 +01:00
|
|
|
}
|
2022-02-23 20:02:02 +01:00
|
|
|
|
2023-02-03 21:47:20 +01:00
|
|
|
// WithUserIDClaim sets the user id claim for the extractor
|
|
|
|
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
|
|
|
|
return func(c *ClaimsExtractor) {
|
|
|
|
c.userIDClaim = userIDClaim
|
2022-02-23 20:02:02 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-02-03 21:47:20 +01:00
|
|
|
// WithFromRequestContext sets the function that extracts claims from the request context
|
|
|
|
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
|
|
|
|
return func(c *ClaimsExtractor) {
|
|
|
|
c.FromRequestContext = ec
|
2022-06-14 10:32:54 +02:00
|
|
|
}
|
2022-05-05 20:02:15 +02:00
|
|
|
}
|
|
|
|
|
2023-02-03 21:47:20 +01:00
|
|
|
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
|
|
|
|
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
|
|
|
|
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
|
|
|
ce := &ClaimsExtractor{}
|
|
|
|
for _, option := range options {
|
|
|
|
option(ce)
|
|
|
|
}
|
|
|
|
if ce.FromRequestContext == nil {
|
|
|
|
ce.FromRequestContext = ce.fromRequestContext
|
|
|
|
}
|
|
|
|
if ce.userIDClaim == "" {
|
2023-03-30 18:59:35 +02:00
|
|
|
ce.userIDClaim = UserIDClaim
|
2023-02-03 21:47:20 +01:00
|
|
|
}
|
|
|
|
return ce
|
|
|
|
}
|
|
|
|
|
|
|
|
// FromToken extracts claims from the token (after auth)
|
|
|
|
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
2022-02-23 20:02:02 +01:00
|
|
|
claims := token.Claims.(jwt.MapClaims)
|
2023-06-27 16:51:05 +02:00
|
|
|
jwtClaims := AuthorizationClaims{
|
|
|
|
Raw: claims,
|
|
|
|
}
|
2023-02-03 21:47:20 +01:00
|
|
|
userID, ok := claims[c.userIDClaim].(string)
|
|
|
|
if !ok {
|
|
|
|
return jwtClaims
|
|
|
|
}
|
|
|
|
jwtClaims.UserId = userID
|
2023-03-30 18:54:55 +02:00
|
|
|
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
|
2022-02-23 20:02:02 +01:00
|
|
|
if ok {
|
2023-02-03 21:47:20 +01:00
|
|
|
jwtClaims.AccountId = accountIDClaim.(string)
|
2022-02-23 20:02:02 +01:00
|
|
|
}
|
2023-03-30 18:54:55 +02:00
|
|
|
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
|
2022-02-23 20:02:02 +01:00
|
|
|
if ok {
|
|
|
|
jwtClaims.Domain = domainClaim.(string)
|
|
|
|
}
|
2023-03-30 18:54:55 +02:00
|
|
|
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
|
2022-03-01 15:22:18 +01:00
|
|
|
if ok {
|
|
|
|
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
|
|
|
}
|
2023-08-18 19:23:11 +02:00
|
|
|
LastLoginClaimString, ok := claims[c.authAudience+LastLoginSuffix]
|
|
|
|
if ok {
|
|
|
|
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
|
|
|
|
}
|
2022-02-23 20:02:02 +01:00
|
|
|
return jwtClaims
|
|
|
|
}
|
2023-02-03 21:47:20 +01:00
|
|
|
|
2023-08-18 19:23:11 +02:00
|
|
|
func parseTime(timeString string) time.Time {
|
|
|
|
if timeString == "" {
|
|
|
|
return time.Time{}
|
|
|
|
}
|
|
|
|
parsedTime, err := time.Parse(time.RFC3339, timeString)
|
|
|
|
if err != nil {
|
|
|
|
return time.Time{}
|
|
|
|
}
|
|
|
|
return parsedTime
|
|
|
|
}
|
|
|
|
|
2023-02-03 21:47:20 +01:00
|
|
|
// fromRequestContext extracts claims from the *jwt.Token stored in the request
// context under TokenUserProperty (presumably placed there by the JWT auth
// middleware — verify against the caller). If the key is missing, holds a
// value of another type, or holds a nil token, it returns empty
// AuthorizationClaims instead of panicking.
//
// NOTE(review): TokenUserProperty is a plain string context key, which
// `staticcheck` (SA1029) flags. It is kept because whatever populates the
// context must use the identical key.
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
	token, ok := r.Context().Value(TokenUserProperty).(*jwt.Token)
	if !ok || token == nil {
		return AuthorizationClaims{}
	}
	return c.FromToken(token)
}
|