Merge pull request #808 from bcmmbaga/main

Add support for refreshing signing keys on expiry
This commit is contained in:
pascal-fischer 2023-05-02 17:17:09 +02:00 committed by GitHub
commit 88678ef364
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 10 deletions

View File

@ -80,6 +80,7 @@ var (
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
}
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
@ -186,6 +187,7 @@ var (
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err)

View File

@ -16,13 +16,14 @@ const (
)
var (
dnsDomain string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
dnsDomain string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
idpSignKeyRefreshEnabled bool
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@ -54,6 +55,7 @@ func init() {
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@ -80,6 +80,8 @@ type HttpServerConfig struct {
AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool
}
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@ -52,7 +52,9 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation)
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}

View File

@ -12,6 +12,10 @@ import (
"fmt"
"math/big"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
@ -45,7 +49,8 @@ type Options struct {
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
}
// JSONWebKey is a representation of a Jason Web Key
@ -64,12 +69,13 @@ type JWTValidator struct {
}
// NewJWTValidator constructor
func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) {
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation)
if err != nil {
return nil, err
}
var lock sync.Mutex
options := Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim
@ -89,6 +95,23 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string)
return token, errors.New("invalid issuer")
}
// If keys are rotated, verify the keys prior to token validation
if idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones
if !keys.stillValid() {
lock.Lock()
defer lock.Unlock()
refreshedKeys, err := getPemKeys(keysLocation)
if err != nil {
log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = keys
}
keys = refreshedKeys
}
}
cert, err := getPemCert(token, keys)
if err != nil {
return nil, err
@ -154,6 +177,11 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
return parsedToken, nil
}
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation)
if err != nil {
@ -167,6 +195,10 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
return jwks, err
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, err
}
@ -248,3 +280,26 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
return int(exponent), nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
for _, directive := range directives {
directive = strings.TrimSpace(directive)
if strings.HasPrefix(directive, "max-age=") {
// Extract the max-age value
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
log.Debugf("error parsing max-age: %v", err)
return 0
}
return maxAge
}
}
return 0
}