Merge pull request #808 from bcmmbaga/main

Add support for refreshing signing keys on expiry
This commit is contained in:
pascal-fischer 2023-05-02 17:17:09 +02:00 committed by GitHub
commit 88678ef364
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 10 deletions

View File

@ -80,6 +80,7 @@ var (
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
} }
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
tlsEnabled := false tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") { if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
@ -186,6 +187,7 @@ var (
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation, config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err) return fmt.Errorf("failed creating JWT validator: %v", err)

View File

@ -16,13 +16,14 @@ const (
) )
var ( var (
dnsDomain string dnsDomain string
mgmtDataDir string mgmtDataDir string
mgmtConfig string mgmtConfig string
logLevel string logLevel string
logFile string logFile string
disableMetrics bool disableMetrics bool
disableSingleAccMode bool disableSingleAccMode bool
idpSignKeyRefreshEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@ -54,6 +55,7 @@ func init() {
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@ -80,6 +80,8 @@ type HttpServerConfig struct {
AuthKeysLocation string AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration // OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
OIDCConfigEndpoint string OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool
} }
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal) // Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@ -52,7 +52,9 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
jwtValidator, err = jwtclaims.NewJWTValidator( jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation) config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
} }

View File

@ -12,6 +12,10 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"net/http" "net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -45,7 +49,8 @@ type Options struct {
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct { type Jwks struct {
Keys []JSONWebKey `json:"keys"` Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
} }
// JSONWebKey is a representation of a Jason Web Key // JSONWebKey is a representation of a Jason Web Key
@ -64,12 +69,13 @@ type JWTValidator struct {
} }
// NewJWTValidator constructor // NewJWTValidator constructor
func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) { func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation) keys, err := getPemKeys(keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var lock sync.Mutex
options := Options{ options := Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim // Verify 'aud' claim
@ -89,6 +95,23 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string)
return token, errors.New("invalid issuer") return token, errors.New("invalid issuer")
} }
// If keys are rotated, verify the keys prior to token validation
if idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones
if !keys.stillValid() {
lock.Lock()
defer lock.Unlock()
refreshedKeys, err := getPemKeys(keysLocation)
if err != nil {
log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = keys
}
keys = refreshedKeys
}
}
cert, err := getPemCert(token, keys) cert, err := getPemCert(token, keys)
if err != nil { if err != nil {
return nil, err return nil, err
@ -154,6 +177,11 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
return parsedToken, nil return parsedToken, nil
} }
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) { func getPemKeys(keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation) resp, err := http.Get(keysLocation)
if err != nil { if err != nil {
@ -167,6 +195,10 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
return jwks, err return jwks, err
} }
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, err return jwks, err
} }
@ -248,3 +280,26 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
return int(exponent), nil return int(exponent), nil
} }
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
for _, directive := range directives {
directive = strings.TrimSpace(directive)
if strings.HasPrefix(directive, "max-age=") {
// Extract the max-age value
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
log.Debugf("error parsing max-age: %v", err)
return 0
}
return maxAge
}
}
return 0
}