Generate validation certificate from mandatory JWK fields (#614)

When there is no X5c we will use N and E fields of 
a JWK to generate the public RSA and a Pem certificate
This commit is contained in:
Maycon Santos 2022-12-07 22:06:43 +01:00 committed by GitHub
parent 0fbfec4ce4
commit 0be46c083d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,19 +1,28 @@
package middleware
import (
"bytes"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"math/big"
"net/http"
)
// Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKeys `json:"keys"`
Keys []JSONWebKey `json:"keys"`
}
// JSONWebKeys is a representation of a Jason Web Key
type JSONWebKeys struct {
// JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
@ -45,7 +54,7 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
cert, err := getPemCert(token, keys)
if err != nil {
panic(err.Error())
return nil, err
}
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
@ -74,18 +83,80 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
}
func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
cert := ""
for k := range jwks.Keys {
if token.Header["kid"] == jwks.Keys[k].Kid {
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
if token.Header["kid"] != jwks.Keys[k].Kid {
continue
}
if len(jwks.Keys[k].X5c) != 0 {
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return cert, nil
}
log.Debugf("generating validation pem from JWK")
return generatePemFromJWK(jwks.Keys[k])
}
if cert == "" {
err := errors.New("unable to find appropriate key")
return cert, err
}
return cert, nil
return "", errors.New("unable to find appropriate key")
}
func generatePemFromJWK(jwk JSONWebKey) (string, error) {
decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err)
}
intModules := big.NewInt(0)
intModules.SetBytes(decodedModulus)
exponent, err := convertExponentStringToInt(jwk.E)
if err != nil {
return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err)
}
publicKey := &rsa.PublicKey{
N: intModules,
E: exponent,
}
derKey, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return "", fmt.Errorf("unable to convert public key to DER, error: %s", err)
}
block := &pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: derKey,
}
var out bytes.Buffer
err = pem.Encode(&out, block)
if err != nil {
return "", fmt.Errorf("unable to encode Pem block , error: %s", err)
}
return out.String(), nil
}
func convertExponentStringToInt(stringExponent string) (int, error) {
decodedString, err := base64.StdEncoding.DecodeString(stringExponent)
if err != nil {
return 0, err
}
exponentBytes := decodedString
if len(decodedString) < 8 {
exponentBytes = make([]byte, 8-len(decodedString), 8)
exponentBytes = append(exponentBytes, decodedString...)
}
bytesReader := bytes.NewReader(exponentBytes)
var exponent uint64
err = binary.Read(bytesReader, binary.BigEndian, &exponent)
if err != nil {
return 0, err
}
return int(exponent), nil
}