[management] Add support to ECDSA public Keys (#2461)

Update the JWT validation logic to handle ECDSA keys in addition to the existing RSA keys

---------

Co-authored-by: Harry Kodden <harry.kodden@surf.nl>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Harry Kodden 2024-08-27 16:37:55 +02:00 committed by GitHub
parent be6bc46bcd
commit 00944bcdbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,14 +1,12 @@
package jwtclaims package jwtclaims
import ( import (
"bytes"
"context" "context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa" "crypto/rsa"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/binary"
"encoding/json" "encoding/json"
"encoding/pem"
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
@ -41,11 +39,6 @@ type Options struct {
// When set, all requests with the OPTIONS method will use authentication // When set, all requests with the OPTIONS method will use authentication
// Default: false // Default: false
EnableAuthOnOptions bool EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
} }
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
@ -54,6 +47,18 @@ type Jwks struct {
expiresInTime time.Time expiresInTime time.Time
} }
// The supported elliptic curves types
const (
// p256 represents a cryptographic elliptical curve type.
p256 = "P-256"
// p384 represents a cryptographic elliptical curve type.
p384 = "P-384"
// p521 represents a cryptographic elliptical curve type.
p521 = "P-521"
)
// JSONWebKey is a representation of a Jason Web Key // JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct { type JSONWebKey struct {
Kty string `json:"kty"` Kty string `json:"kty"`
@ -61,6 +66,9 @@ type JSONWebKey struct {
Use string `json:"use"` Use string `json:"use"`
N string `json:"n"` N string `json:"n"`
E string `json:"e"` E string `json:"e"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
X5c []string `json:"x5c"` X5c []string `json:"x5c"`
} }
@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
} }
} }
cert, err := getPemCert(ctx, token, keys) publicKey, err := getPublicKey(ctx, token, keys)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
return nil, err return nil, err
} }
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) return publicKey, nil
return result, nil
}, },
SigningMethod: jwt.SigningMethodRS256,
EnableAuthOnOptions: false, EnableAuthOnOptions: false,
} }
@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// Check if there was an error in parsing... // Check if there was an error in parsing...
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error parsing token: %v", err) log.WithContext(ctx).Errorf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err) return nil, fmt.Errorf("error parsing token: %w", err)
}
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.options.SigningMethod.Alg(),
parsedToken.Header["alg"])
log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
} }
// Check if the parsed token is valid... // Check if the parsed token is valid...
@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
return jwks, err return jwks, err
} }
func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) { func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
cert := ""
for k := range jwks.Keys { for k := range jwks.Keys {
if token.Header["kid"] != jwks.Keys[k].Kid { if token.Header["kid"] != jwks.Keys[k].Kid {
@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro
} }
if len(jwks.Keys[k].X5c) != 0 { if len(jwks.Keys[k].X5c) != 0 {
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return cert, nil return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
}
log.WithContext(ctx).Debugf("generating validation pem from JWK")
return generatePemFromJWK(jwks.Keys[k])
} }
return cert, errors.New("unable to find appropriate key") if jwks.Keys[k].Kty == "RSA" {
log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
return getPublicKeyFromRSA(jwks.Keys[k])
}
if jwks.Keys[k].Kty == "EC" {
log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
return getPublicKeyFromECDSA(jwks.Keys[k])
}
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
}
return nil, errors.New("unable to find appropriate key")
} }
func generatePemFromJWK(jwk JSONWebKey) (string, error) { func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil { if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err) return nil, fmt.Errorf("ecdsa key incomplete")
} }
intModules := big.NewInt(0) var xCoordinate []byte
intModules.SetBytes(decodedModulus) if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
return nil, err
exponent, err := convertExponentStringToInt(jwk.E)
if err != nil {
return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err)
} }
publicKey := &rsa.PublicKey{ var yCoordinate []byte
N: intModules, if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
E: exponent, return nil, err
} }
derKey, err := x509.MarshalPKIXPublicKey(publicKey) publicKey = &ecdsa.PublicKey{}
if err != nil {
return "", fmt.Errorf("unable to convert public key to DER, error: %s", err) var curve elliptic.Curve
switch jwk.Crv {
case p256:
curve = elliptic.P256()
case p384:
curve = elliptic.P384()
case p521:
curve = elliptic.P521()
} }
block := &pem.Block{ publicKey.Curve = curve
Type: "RSA PUBLIC KEY", publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
Bytes: derKey, publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
}
var out bytes.Buffer return publicKey, nil
err = pem.Encode(&out, block)
if err != nil {
return "", fmt.Errorf("unable to encode Pem block , error: %s", err)
}
return out.String(), nil
} }
func convertExponentStringToInt(stringExponent string) (int, error) { func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
decodedString, err := base64.StdEncoding.DecodeString(stringExponent)
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil { if err != nil {
return 0, err return nil, err
} }
exponentBytes := decodedString decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
if len(decodedString) < 8 { if err != nil {
exponentBytes = make([]byte, 8-len(decodedString), 8) return nil, err
exponentBytes = append(exponentBytes, decodedString...)
} }
bytesReader := bytes.NewReader(exponentBytes) var n, e big.Int
var exponent uint64 e.SetBytes(decodedE)
err = binary.Read(bytesReader, binary.BigEndian, &exponent) n.SetBytes(decodedN)
if err != nil {
return 0, err
}
return int(exponent), nil return &rsa.PublicKey{
E: int(e.Int64()),
N: &n,
}, nil
} }
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header // getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0 return 0
} }