diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index 39676982e..d5c1e7c9e 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -1,14 +1,12 @@ package jwtclaims import ( - "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rsa" - "crypto/x509" "encoding/base64" - "encoding/binary" "encoding/json" - "encoding/pem" "errors" "fmt" "math/big" @@ -41,11 +39,6 @@ type Options struct { // When set, all requests with the OPTIONS method will use authentication // Default: false EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod } // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation @@ -54,6 +47,18 @@ type Jwks struct { expiresInTime time.Time } +// The supported elliptic curves types +const ( + // p256 represents a cryptographic elliptical curve type. + p256 = "P-256" + + // p384 represents a cryptographic elliptical curve type. + p384 = "P-384" + + // p521 represents a cryptographic elliptical curve type. + p521 = "P-521" +) + // JSONWebKey is a representation of a Jason Web Key type JSONWebKey struct { Kty string `json:"kty"` @@ -61,6 +66,9 @@ type JSONWebKey struct { Use string `json:"use"` N string `json:"n"` E string `json:"e"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` X5c []string `json:"x5c"` } @@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, } } - cert, err := getPemCert(ctx, token, keys) + publicKey, err := getPublicKey(ctx, token, keys) if err != nil { + log.WithContext(ctx).Errorf("getPublicKey error: %s", err) return nil, err } - result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) - return result, nil + return publicKey, nil }, - SigningMethod: jwt.SigningMethodRS256, EnableAuthOnOptions: false, } @@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt // Check if there was an error in parsing... if err != nil { log.WithContext(ctx).Errorf("error parsing token: %v", err) - return nil, fmt.Errorf("Error parsing token: %w", err) - } - - if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] { - errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", - m.options.SigningMethod.Alg(), - parsedToken.Header["alg"]) - log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg) - return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) + return nil, fmt.Errorf("error parsing token: %w", err) } // Check if the parsed token is valid... @@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) { return jwks, err } -func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) { +func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) { // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time - cert := "" for k := range jwks.Keys { if token.Header["kid"] != jwks.Keys[k].Kid { @@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro } if len(jwks.Keys[k].X5c) != 0 { - cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" - return cert, nil + cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" + return jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) } - log.WithContext(ctx).Debugf("generating validation pem from JWK") - return generatePemFromJWK(jwks.Keys[k]) + + if jwks.Keys[k].Kty == "RSA" { + log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK") + return getPublicKeyFromRSA(jwks.Keys[k]) + } + if jwks.Keys[k].Kty == "EC" { + log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK") + return getPublicKeyFromECDSA(jwks.Keys[k]) + } + + log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty) } - return cert, errors.New("unable to find appropriate key") + return nil, errors.New("unable to find appropriate key") } -func generatePemFromJWK(jwk JSONWebKey) (string, error) { - decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err) +func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { + + if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" { + return nil, fmt.Errorf("ecdsa key incomplete") } - intModules := big.NewInt(0) - intModules.SetBytes(decodedModulus) - - exponent, err := convertExponentStringToInt(jwk.E) - if err != nil { - return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err) + var xCoordinate []byte + if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil { + return nil, err } - publicKey := &rsa.PublicKey{ - N: intModules, - E: exponent, + var yCoordinate []byte + if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil { + return nil, err } - derKey, err := x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - return "", fmt.Errorf("unable to convert public key to DER, error: %s", err) + publicKey = &ecdsa.PublicKey{} + + var curve elliptic.Curve + switch jwk.Crv { + case p256: + curve = elliptic.P256() + case p384: + curve = elliptic.P384() + case p521: + curve = elliptic.P521() } - block := &pem.Block{ - Type: "RSA PUBLIC KEY", - Bytes: derKey, - } + publicKey.Curve = curve + publicKey.X = big.NewInt(0).SetBytes(xCoordinate) + publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) - var out bytes.Buffer - err = pem.Encode(&out, block) - if err != nil { - return "", fmt.Errorf("unable to encode Pem block , error: %s", err) - } - - return out.String(), nil + return publicKey, nil } -func convertExponentStringToInt(stringExponent string) (int, error) { - decodedString, err := base64.StdEncoding.DecodeString(stringExponent) +func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) { + + decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E) if err != nil { - return 0, err + return nil, err } - exponentBytes := decodedString - if len(decodedString) < 8 { - exponentBytes = make([]byte, 8-len(decodedString), 8) - exponentBytes = append(exponentBytes, decodedString...) + decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, err } - bytesReader := bytes.NewReader(exponentBytes) - var exponent uint64 - err = binary.Read(bytesReader, binary.BigEndian, &exponent) - if err != nil { - return 0, err - } + var n, e big.Int + e.SetBytes(decodedE) + n.SetBytes(decodedN) - return int(exponent), nil + return &rsa.PublicKey{ + E: int(e.Int64()), + N: &n, + }, nil } // getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header @@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { return 0 } +