mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-12 08:58:44 +01:00
105 lines
3.0 KiB
Go
105 lines
3.0 KiB
Go
package jwtclaims
|
|
|
|
import (
|
|
"net/http"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
)
|
|
|
|
const (
|
|
// TokenUserProperty key for the user property in the request context
|
|
TokenUserProperty = "user"
|
|
// AccountIDSuffix suffix for the account id claim
|
|
AccountIDSuffix = "wt_account_id"
|
|
// DomainIDSuffix suffix for the domain id claim
|
|
DomainIDSuffix = "wt_account_domain"
|
|
// DomainCategorySuffix suffix for the domain category claim
|
|
DomainCategorySuffix = "wt_account_domain_category"
|
|
// UserIDClaim claim for the user id
|
|
UserIDClaim = "sub"
|
|
)
|
|
|
|
// ExtractClaims Extract function type
|
|
type ExtractClaims func(r *http.Request) AuthorizationClaims
|
|
|
|
// ClaimsExtractor struct that holds the extract function
|
|
type ClaimsExtractor struct {
|
|
authAudience string
|
|
userIDClaim string
|
|
|
|
FromRequestContext ExtractClaims
|
|
}
|
|
|
|
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
|
|
type ClaimsExtractorOption func(*ClaimsExtractor)
|
|
|
|
// WithAudience sets the audience for the extractor
|
|
func WithAudience(audience string) ClaimsExtractorOption {
|
|
return func(c *ClaimsExtractor) {
|
|
c.authAudience = audience
|
|
}
|
|
}
|
|
|
|
// WithUserIDClaim sets the user id claim for the extractor
|
|
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
|
|
return func(c *ClaimsExtractor) {
|
|
c.userIDClaim = userIDClaim
|
|
}
|
|
}
|
|
|
|
// WithFromRequestContext sets the function that extracts claims from the request context
|
|
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
|
|
return func(c *ClaimsExtractor) {
|
|
c.FromRequestContext = ec
|
|
}
|
|
}
|
|
|
|
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
|
|
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
|
|
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
|
ce := &ClaimsExtractor{}
|
|
for _, option := range options {
|
|
option(ce)
|
|
}
|
|
if ce.FromRequestContext == nil {
|
|
ce.FromRequestContext = ce.fromRequestContext
|
|
}
|
|
if ce.userIDClaim == "" {
|
|
ce.userIDClaim = UserIDClaim
|
|
}
|
|
return ce
|
|
}
|
|
|
|
// FromToken extracts claims from the token (after auth)
|
|
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
|
claims := token.Claims.(jwt.MapClaims)
|
|
jwtClaims := AuthorizationClaims{}
|
|
userID, ok := claims[c.userIDClaim].(string)
|
|
if !ok {
|
|
return jwtClaims
|
|
}
|
|
jwtClaims.UserId = userID
|
|
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
|
|
if ok {
|
|
jwtClaims.AccountId = accountIDClaim.(string)
|
|
}
|
|
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
|
|
if ok {
|
|
jwtClaims.Domain = domainClaim.(string)
|
|
}
|
|
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
|
|
if ok {
|
|
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
|
}
|
|
return jwtClaims
|
|
}
|
|
|
|
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
|
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
|
|
if r.Context().Value(TokenUserProperty) == nil {
|
|
return AuthorizationClaims{}
|
|
}
|
|
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
|
return c.FromToken(token)
|
|
}
|