validate keys for idp's with key rotation mechanism

This commit is contained in:
Bethuel 2023-04-14 12:20:34 +03:00
parent a89808ecae
commit 9f352c1b7e
2 changed files with 53 additions and 2 deletions

View File

@ -80,6 +80,8 @@ type HttpServerConfig struct {
AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
OIDCConfigEndpoint string
// KeyRotationEnabled identifies the signing key is currently being rotated or not
KeyRotationEnabled bool
}
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@ -12,6 +12,9 @@ import (
"fmt"
"math/big"
"net/http"
"strconv"
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
@ -45,7 +48,8 @@ type Options struct {
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
}
// JSONWebKey is a representation of a Jason Web Key
@ -64,7 +68,7 @@ type JWTValidator struct {
}
// NewJWTValidator constructor
func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) {
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, keyRotationEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation)
if err != nil {
return nil, err
@ -89,6 +93,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string)
return token, errors.New("invalid issuer")
}
// If keys are rotated, verify the keys prior to token validation
if keyRotationEnabled {
// If the keys are invalid, retrieve new ones
if !keys.stillValid() {
keys, err = getPemKeys(keysLocation)
if err != nil {
log.Errorf("cannot get JSONWebKey: %v", err)
return nil, err
}
}
}
cert, err := getPemCert(token, keys)
if err != nil {
return nil, err
@ -154,6 +171,11 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
return parsedToken, nil
}
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation)
if err != nil {
@ -167,6 +189,10 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
return jwks, err
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, err
}
@ -248,3 +274,26 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
return int(exponent), nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
for _, directive := range directives {
directive = strings.TrimSpace(directive)
if strings.HasPrefix(directive, "max-age=") {
// Extract the max-age value
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
log.Debugf("error parsing max-age: %v", err)
return 0
}
return maxAge
}
}
return 0
}