mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 09:33:24 +01:00
[management] Add support to ECDSA public Keys (#2461)
Update the JWT validation logic to handle ECDSA keys in addition to the existing RSA keys --------- Co-authored-by: Harry Kodden <harry.kodden@surf.nl> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
be6bc46bcd
commit
00944bcdbf
@ -1,14 +1,12 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@ -41,11 +39,6 @@ type Options struct {
|
||||
// When set, all requests with the OPTIONS method will use authentication
|
||||
// Default: false
|
||||
EnableAuthOnOptions bool
|
||||
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
|
||||
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
|
||||
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
||||
// Default: nil
|
||||
SigningMethod jwt.SigningMethod
|
||||
}
|
||||
|
||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
@ -54,6 +47,18 @@ type Jwks struct {
|
||||
expiresInTime time.Time
|
||||
}
|
||||
|
||||
// The supported elliptic curves types
|
||||
const (
|
||||
// p256 represents a cryptographic elliptical curve type.
|
||||
p256 = "P-256"
|
||||
|
||||
// p384 represents a cryptographic elliptical curve type.
|
||||
p384 = "P-384"
|
||||
|
||||
// p521 represents a cryptographic elliptical curve type.
|
||||
p521 = "P-521"
|
||||
)
|
||||
|
||||
// JSONWebKey is a representation of a Jason Web Key
|
||||
type JSONWebKey struct {
|
||||
Kty string `json:"kty"`
|
||||
@ -61,6 +66,9 @@ type JSONWebKey struct {
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
|
||||
}
|
||||
}
|
||||
|
||||
cert, err := getPemCert(ctx, token, keys)
|
||||
publicKey, err := getPublicKey(ctx, token, keys)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
||||
return result, nil
|
||||
return publicKey, nil
|
||||
},
|
||||
SigningMethod: jwt.SigningMethodRS256,
|
||||
EnableAuthOnOptions: false,
|
||||
}
|
||||
|
||||
@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
|
||||
// Check if there was an error in parsing...
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error parsing token: %v", err)
|
||||
return nil, fmt.Errorf("Error parsing token: %w", err)
|
||||
}
|
||||
|
||||
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
|
||||
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
|
||||
m.options.SigningMethod.Alg(),
|
||||
parsedToken.Header["alg"])
|
||||
log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
|
||||
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
|
||||
return nil, fmt.Errorf("error parsing token: %w", err)
|
||||
}
|
||||
|
||||
// Check if the parsed token is valid...
|
||||
@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
|
||||
func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
|
||||
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
|
||||
cert := ""
|
||||
|
||||
for k := range jwks.Keys {
|
||||
if token.Header["kid"] != jwks.Keys[k].Kid {
|
||||
@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro
|
||||
}
|
||||
|
||||
if len(jwks.Keys[k].X5c) != 0 {
|
||||
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||
return cert, nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("generating validation pem from JWK")
|
||||
return generatePemFromJWK(jwks.Keys[k])
|
||||
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
||||
}
|
||||
|
||||
return cert, errors.New("unable to find appropriate key")
|
||||
if jwks.Keys[k].Kty == "RSA" {
|
||||
log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
|
||||
return getPublicKeyFromRSA(jwks.Keys[k])
|
||||
}
|
||||
if jwks.Keys[k].Kty == "EC" {
|
||||
log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
|
||||
return getPublicKeyFromECDSA(jwks.Keys[k])
|
||||
}
|
||||
|
||||
func generatePemFromJWK(jwk JSONWebKey) (string, error) {
|
||||
decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
|
||||
}
|
||||
|
||||
return nil, errors.New("unable to find appropriate key")
|
||||
}
|
||||
|
||||
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
|
||||
|
||||
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
|
||||
return nil, fmt.Errorf("ecdsa key incomplete")
|
||||
}
|
||||
|
||||
var xCoordinate []byte
|
||||
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var yCoordinate []byte
|
||||
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey = &ecdsa.PublicKey{}
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch jwk.Crv {
|
||||
case p256:
|
||||
curve = elliptic.P256()
|
||||
case p384:
|
||||
curve = elliptic.P384()
|
||||
case p521:
|
||||
curve = elliptic.P521()
|
||||
}
|
||||
|
||||
publicKey.Curve = curve
|
||||
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
|
||||
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
|
||||
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
|
||||
|
||||
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
intModules := big.NewInt(0)
|
||||
intModules.SetBytes(decodedModulus)
|
||||
|
||||
exponent, err := convertExponentStringToInt(jwk.E)
|
||||
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey := &rsa.PublicKey{
|
||||
N: intModules,
|
||||
E: exponent,
|
||||
}
|
||||
var n, e big.Int
|
||||
e.SetBytes(decodedE)
|
||||
n.SetBytes(decodedN)
|
||||
|
||||
derKey, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to convert public key to DER, error: %s", err)
|
||||
}
|
||||
|
||||
block := &pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: derKey,
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
err = pem.Encode(&out, block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to encode Pem block , error: %s", err)
|
||||
}
|
||||
|
||||
return out.String(), nil
|
||||
}
|
||||
|
||||
func convertExponentStringToInt(stringExponent string) (int, error) {
|
||||
decodedString, err := base64.StdEncoding.DecodeString(stringExponent)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
exponentBytes := decodedString
|
||||
if len(decodedString) < 8 {
|
||||
exponentBytes = make([]byte, 8-len(decodedString), 8)
|
||||
exponentBytes = append(exponentBytes, decodedString...)
|
||||
}
|
||||
|
||||
bytesReader := bytes.NewReader(exponentBytes)
|
||||
var exponent uint64
|
||||
err = binary.Read(bytesReader, binary.BigEndian, &exponent)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int(exponent), nil
|
||||
return &rsa.PublicKey{
|
||||
E: int(e.Int64()),
|
||||
N: &n,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
|
||||
@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user