mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-22 08:03:30 +01:00
[management] Add support to ECDSA public Keys (#2461)
Update the JWT validation logic to handle ECDSA keys in addition to the existing RSA keys --------- Co-authored-by: Harry Kodden <harry.kodden@surf.nl> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
be6bc46bcd
commit
00944bcdbf
@ -1,14 +1,12 @@
|
|||||||
package jwtclaims
|
package jwtclaims
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
@ -41,11 +39,6 @@ type Options struct {
|
|||||||
// When set, all requests with the OPTIONS method will use authentication
|
// When set, all requests with the OPTIONS method will use authentication
|
||||||
// Default: false
|
// Default: false
|
||||||
EnableAuthOnOptions bool
|
EnableAuthOnOptions bool
|
||||||
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
|
|
||||||
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
|
|
||||||
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
|
||||||
// Default: nil
|
|
||||||
SigningMethod jwt.SigningMethod
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||||
@ -54,6 +47,18 @@ type Jwks struct {
|
|||||||
expiresInTime time.Time
|
expiresInTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The supported elliptic curves types
|
||||||
|
const (
|
||||||
|
// p256 represents a cryptographic elliptical curve type.
|
||||||
|
p256 = "P-256"
|
||||||
|
|
||||||
|
// p384 represents a cryptographic elliptical curve type.
|
||||||
|
p384 = "P-384"
|
||||||
|
|
||||||
|
// p521 represents a cryptographic elliptical curve type.
|
||||||
|
p521 = "P-521"
|
||||||
|
)
|
||||||
|
|
||||||
// JSONWebKey is a representation of a Jason Web Key
|
// JSONWebKey is a representation of a Jason Web Key
|
||||||
type JSONWebKey struct {
|
type JSONWebKey struct {
|
||||||
Kty string `json:"kty"`
|
Kty string `json:"kty"`
|
||||||
@ -61,6 +66,9 @@ type JSONWebKey struct {
|
|||||||
Use string `json:"use"`
|
Use string `json:"use"`
|
||||||
N string `json:"n"`
|
N string `json:"n"`
|
||||||
E string `json:"e"`
|
E string `json:"e"`
|
||||||
|
Crv string `json:"crv"`
|
||||||
|
X string `json:"x"`
|
||||||
|
Y string `json:"y"`
|
||||||
X5c []string `json:"x5c"`
|
X5c []string `json:"x5c"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := getPemCert(ctx, token, keys)
|
publicKey, err := getPublicKey(ctx, token, keys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
return publicKey, nil
|
||||||
return result, nil
|
|
||||||
},
|
},
|
||||||
SigningMethod: jwt.SigningMethodRS256,
|
|
||||||
EnableAuthOnOptions: false,
|
EnableAuthOnOptions: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
|
|||||||
// Check if there was an error in parsing...
|
// Check if there was an error in parsing...
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error parsing token: %v", err)
|
log.WithContext(ctx).Errorf("error parsing token: %v", err)
|
||||||
return nil, fmt.Errorf("Error parsing token: %w", err)
|
return nil, fmt.Errorf("error parsing token: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
|
|
||||||
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
|
|
||||||
m.options.SigningMethod.Alg(),
|
|
||||||
parsedToken.Header["alg"])
|
|
||||||
log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
|
|
||||||
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the parsed token is valid...
|
// Check if the parsed token is valid...
|
||||||
@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
|
|||||||
return jwks, err
|
return jwks, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
|
func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
|
||||||
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
|
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
|
||||||
cert := ""
|
|
||||||
|
|
||||||
for k := range jwks.Keys {
|
for k := range jwks.Keys {
|
||||||
if token.Header["kid"] != jwks.Keys[k].Kid {
|
if token.Header["kid"] != jwks.Keys[k].Kid {
|
||||||
@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(jwks.Keys[k].X5c) != 0 {
|
if len(jwks.Keys[k].X5c) != 0 {
|
||||||
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||||
return cert, nil
|
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("generating validation pem from JWK")
|
|
||||||
return generatePemFromJWK(jwks.Keys[k])
|
if jwks.Keys[k].Kty == "RSA" {
|
||||||
|
log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
|
||||||
|
return getPublicKeyFromRSA(jwks.Keys[k])
|
||||||
|
}
|
||||||
|
if jwks.Keys[k].Kty == "EC" {
|
||||||
|
log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
|
||||||
|
return getPublicKeyFromECDSA(jwks.Keys[k])
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cert, errors.New("unable to find appropriate key")
|
return nil, errors.New("unable to find appropriate key")
|
||||||
}
|
}
|
||||||
|
|
||||||
func generatePemFromJWK(jwk JSONWebKey) (string, error) {
|
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
|
||||||
decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
|
||||||
if err != nil {
|
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
|
||||||
return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err)
|
return nil, fmt.Errorf("ecdsa key incomplete")
|
||||||
}
|
}
|
||||||
|
|
||||||
intModules := big.NewInt(0)
|
var xCoordinate []byte
|
||||||
intModules.SetBytes(decodedModulus)
|
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
|
||||||
|
return nil, err
|
||||||
exponent, err := convertExponentStringToInt(jwk.E)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey := &rsa.PublicKey{
|
var yCoordinate []byte
|
||||||
N: intModules,
|
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
|
||||||
E: exponent,
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
derKey, err := x509.MarshalPKIXPublicKey(publicKey)
|
publicKey = &ecdsa.PublicKey{}
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("unable to convert public key to DER, error: %s", err)
|
var curve elliptic.Curve
|
||||||
|
switch jwk.Crv {
|
||||||
|
case p256:
|
||||||
|
curve = elliptic.P256()
|
||||||
|
case p384:
|
||||||
|
curve = elliptic.P384()
|
||||||
|
case p521:
|
||||||
|
curve = elliptic.P521()
|
||||||
}
|
}
|
||||||
|
|
||||||
block := &pem.Block{
|
publicKey.Curve = curve
|
||||||
Type: "RSA PUBLIC KEY",
|
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
|
||||||
Bytes: derKey,
|
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
|
||||||
}
|
|
||||||
|
|
||||||
var out bytes.Buffer
|
return publicKey, nil
|
||||||
err = pem.Encode(&out, block)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("unable to encode Pem block , error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out.String(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertExponentStringToInt(stringExponent string) (int, error) {
|
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
|
||||||
decodedString, err := base64.StdEncoding.DecodeString(stringExponent)
|
|
||||||
|
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
exponentBytes := decodedString
|
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||||
if len(decodedString) < 8 {
|
if err != nil {
|
||||||
exponentBytes = make([]byte, 8-len(decodedString), 8)
|
return nil, err
|
||||||
exponentBytes = append(exponentBytes, decodedString...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bytesReader := bytes.NewReader(exponentBytes)
|
var n, e big.Int
|
||||||
var exponent uint64
|
e.SetBytes(decodedE)
|
||||||
err = binary.Read(bytesReader, binary.BigEndian, &exponent)
|
n.SetBytes(decodedN)
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(exponent), nil
|
return &rsa.PublicKey{
|
||||||
|
E: int(e.Int64()),
|
||||||
|
N: &n,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
|
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
|
||||||
@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
|
|||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user