netbird/management/server/jwtclaims/extractor.go
Givi Khojanashvili 3ec8274b8e
Feature: add custom id claim (#667)
This feature allows using the custom claim in the JWT token as a user ID.

Refactor claims extractor with options support

Add is_current to the user API response
2023-02-03 21:47:20 +01:00

100 lines
2.8 KiB
Go

package jwtclaims
import (
"net/http"
"github.com/golang-jwt/jwt"
)
const (
TokenUserProperty = "user"
AccountIDSuffix = "wt_account_id"
DomainIDSuffix = "wt_account_domain"
DomainCategorySuffix = "wt_account_domain_category"
UserIDClaim = "sub"
)
// Extract function type
type ExtractClaims func(r *http.Request) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
authAudience string
userIDClaim string
FromRequestContext ExtractClaims
}
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
type ClaimsExtractorOption func(*ClaimsExtractor)
// WithAudience sets the audience for the extractor
func WithAudience(audience string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.authAudience = audience
}
}
// WithUserIDClaim sets the user id claim for the extractor
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.userIDClaim = userIDClaim
}
}
// WithFromRequestContext sets the function that extracts claims from the request context
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.FromRequestContext = ec
}
}
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
ce := &ClaimsExtractor{}
for _, option := range options {
option(ce)
}
if ce.FromRequestContext == nil {
ce.FromRequestContext = ce.fromRequestContext
}
if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim
}
return ce
}
// FromToken extracts claims from the token (after auth)
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return jwtClaims
}
jwtClaims.UserId = userID
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIDClaim.(string)
}
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
if ok {
jwtClaims.Domain = domainClaim.(string)
}
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
if ok {
jwtClaims.DomainCategory = domainCategoryClaim.(string)
}
return jwtClaims
}
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return c.FromToken(token)
}