[management] Add support to ECDSA public Keys (#2461)

Update the JWT validation logic to handle ECDSA keys in addition to the existing RSA keys

---------

Co-authored-by: Harry Kodden <harry.kodden@surf.nl>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Harry Kodden 2024-08-27 16:37:55 +02:00 committed by GitHub
parent be6bc46bcd
commit 00944bcdbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,14 +1,12 @@
package jwtclaims
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math/big"
@ -41,11 +39,6 @@ type Options struct {
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
@ -54,6 +47,18 @@ type Jwks struct {
expiresInTime time.Time
}
// The supported elliptic curves types
const (
// p256 represents a cryptographic elliptical curve type.
p256 = "P-256"
// p384 represents a cryptographic elliptical curve type.
p384 = "P-384"
// p521 represents a cryptographic elliptical curve type.
p521 = "P-521"
)
// JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct {
Kty string `json:"kty"`
@ -61,6 +66,9 @@ type JSONWebKey struct {
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
X5c []string `json:"x5c"`
}
@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
}
}
cert, err := getPemCert(ctx, token, keys)
publicKey, err := getPublicKey(ctx, token, keys)
if err != nil {
log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
return nil, err
}
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
return result, nil
return publicKey, nil
},
SigningMethod: jwt.SigningMethodRS256,
EnableAuthOnOptions: false,
}
@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// Check if there was an error in parsing...
if err != nil {
log.WithContext(ctx).Errorf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.options.SigningMethod.Alg(),
parsedToken.Header["alg"])
log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error parsing token: %w", err)
}
// Check if the parsed token is valid...
@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
return jwks, err
}
func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
cert := ""
for k := range jwks.Keys {
if token.Header["kid"] != jwks.Keys[k].Kid {
@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro
}
if len(jwks.Keys[k].X5c) != 0 {
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return cert, nil
}
log.WithContext(ctx).Debugf("generating validation pem from JWK")
return generatePemFromJWK(jwks.Keys[k])
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
}
return cert, errors.New("unable to find appropriate key")
if jwks.Keys[k].Kty == "RSA" {
log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
return getPublicKeyFromRSA(jwks.Keys[k])
}
if jwks.Keys[k].Kty == "EC" {
log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
return getPublicKeyFromECDSA(jwks.Keys[k])
}
func generatePemFromJWK(jwk JSONWebKey) (string, error) {
decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N)
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
}
return nil, errors.New("unable to find appropriate key")
}
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
return nil, fmt.Errorf("ecdsa key incomplete")
}
var xCoordinate []byte
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
return nil, err
}
var yCoordinate []byte
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
return nil, err
}
publicKey = &ecdsa.PublicKey{}
var curve elliptic.Curve
switch jwk.Crv {
case p256:
curve = elliptic.P256()
case p384:
curve = elliptic.P384()
case p521:
curve = elliptic.P521()
}
publicKey.Curve = curve
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
return publicKey, nil
}
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err)
return nil, err
}
intModules := big.NewInt(0)
intModules.SetBytes(decodedModulus)
exponent, err := convertExponentStringToInt(jwk.E)
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err)
return nil, err
}
publicKey := &rsa.PublicKey{
N: intModules,
E: exponent,
}
var n, e big.Int
e.SetBytes(decodedE)
n.SetBytes(decodedN)
derKey, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return "", fmt.Errorf("unable to convert public key to DER, error: %s", err)
}
block := &pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: derKey,
}
var out bytes.Buffer
err = pem.Encode(&out, block)
if err != nil {
return "", fmt.Errorf("unable to encode Pem block , error: %s", err)
}
return out.String(), nil
}
func convertExponentStringToInt(stringExponent string) (int, error) {
decodedString, err := base64.StdEncoding.DecodeString(stringExponent)
if err != nil {
return 0, err
}
exponentBytes := decodedString
if len(decodedString) < 8 {
exponentBytes = make([]byte, 8-len(decodedString), 8)
exponentBytes = append(exponentBytes, decodedString...)
}
bytesReader := bytes.NewReader(exponentBytes)
var exponent uint64
err = binary.Read(bytesReader, binary.BigEndian, &exponent)
if err != nil {
return 0, err
}
return int(exponent), nil
return &rsa.PublicKey{
E: int(e.Int64()),
N: &n,
}, nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0
}