diff --git a/management/server/http/middleware/handler.go b/management/server/http/middleware/handler.go index 89b8410dd..8725a3d27 100644 --- a/management/server/http/middleware/handler.go +++ b/management/server/http/middleware/handler.go @@ -1,19 +1,28 @@ package middleware import ( + "bytes" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/binary" "encoding/json" + "encoding/pem" "errors" + "fmt" "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" + "math/big" "net/http" ) -// Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation +// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { - Keys []JSONWebKeys `json:"keys"` + Keys []JSONWebKey `json:"keys"` } -// JSONWebKeys is a representation of a Jason Web Key -type JSONWebKeys struct { +// JSONWebKey is a representation of a Jason Web Key +type JSONWebKey struct { Kty string `json:"kty"` Kid string `json:"kid"` Use string `json:"use"` @@ -45,7 +54,7 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT cert, err := getPemCert(token, keys) if err != nil { - panic(err.Error()) + return nil, err } result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) @@ -74,18 +83,80 @@ func getPemKeys(keysLocation string) (*Jwks, error) { } func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) { + // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time cert := "" for k := range jwks.Keys { - if token.Header["kid"] == jwks.Keys[k].Kid { - cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" + if token.Header["kid"] != jwks.Keys[k].Kid { + continue } + + if len(jwks.Keys[k].X5c) != 0 { + cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" + return cert, nil + } + log.Debugf("generating validation pem from JWK") + return generatePemFromJWK(jwks.Keys[k]) } - if cert == "" { - err := errors.New("unable to find appropriate key") - return cert, err - } - - return cert, nil + return "", errors.New("unable to find appropriate key") +} + +func generatePemFromJWK(jwk JSONWebKey) (string, error) { + decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err) + } + + intModules := big.NewInt(0) + intModules.SetBytes(decodedModulus) + + exponent, err := convertExponentStringToInt(jwk.E) + if err != nil { + return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err) + } + + publicKey := &rsa.PublicKey{ + N: intModules, + E: exponent, + } + + derKey, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return "", fmt.Errorf("unable to convert public key to DER, error: %s", err) + } + + block := &pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: derKey, + } + + var out bytes.Buffer + err = pem.Encode(&out, block) + if err != nil { + return "", fmt.Errorf("unable to encode Pem block , error: %s", err) + } + + return out.String(), nil +} + +func convertExponentStringToInt(stringExponent string) (int, error) { + decodedString, err := base64.StdEncoding.DecodeString(stringExponent) + if err != nil { + return 0, err + } + exponentBytes := decodedString + if len(decodedString) < 8 { + exponentBytes = make([]byte, 8-len(decodedString), 8) + exponentBytes = append(exponentBytes, decodedString...) + } + + bytesReader := bytes.NewReader(exponentBytes) + var exponent uint64 + err = binary.Read(bytesReader, binary.BigEndian, &exponent) + if err != nil { + return 0, err + } + + return int(exponent), nil }