refactoring

This commit is contained in:
Bethuel 2023-04-15 03:44:42 +03:00
parent 53d78ad982
commit f7196cd9a5
5 changed files with 17 additions and 18 deletions

View File

@ -80,7 +80,7 @@ var (
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
} }
config.HttpConfig.KeyRotationEnabled = useKeyCacheHeaders config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
tlsEnabled := false tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") { if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
@ -187,7 +187,7 @@ var (
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation, config.HttpConfig.AuthKeysLocation,
config.HttpConfig.KeyRotationEnabled, config.HttpConfig.IdpSignKeyRefreshEnabled,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err) return fmt.Errorf("failed creating JWT validator: %v", err)

View File

@ -16,14 +16,14 @@ const (
) )
var ( var (
dnsDomain string dnsDomain string
mgmtDataDir string mgmtDataDir string
mgmtConfig string mgmtConfig string
logLevel string logLevel string
logFile string logFile string
disableMetrics bool disableMetrics bool
disableSingleAccMode bool disableSingleAccMode bool
useKeyCacheHeaders bool idpSignKeyRefreshEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@ -55,7 +55,7 @@ func init() {
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max lenght is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&useKeyCacheHeaders, "use-key-cache-headers", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@ -80,8 +80,8 @@ type HttpServerConfig struct {
AuthKeysLocation string AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration // OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
OIDCConfigEndpoint string OIDCConfigEndpoint string
// KeyRotationEnabled identifies the signing key is currently being rotated or not // IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
KeyRotationEnabled bool IdpSignKeyRefreshEnabled bool
} }
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal) // Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@ -53,7 +53,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation, config.HttpConfig.AuthKeysLocation,
config.HttpConfig.KeyRotationEnabled, config.HttpConfig.IdpSignKeyRefreshEnabled,
) )
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)

View File

@ -68,7 +68,7 @@ type JWTValidator struct {
} }
// NewJWTValidator constructor // NewJWTValidator constructor
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, keyRotationEnabled bool) (*JWTValidator, error) { func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation) keys, err := getPemKeys(keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
@ -94,13 +94,12 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string,
} }
// If keys are rotated, verify the keys prior to token validation // If keys are rotated, verify the keys prior to token validation
if keyRotationEnabled { if idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones // If the keys are invalid, retrieve new ones
if !keys.stillValid() { if !keys.stillValid() {
keys, err = getPemKeys(keysLocation) keys, err = getPemKeys(keysLocation)
if err != nil { if err != nil {
log.Errorf("cannot get JSONWebKey: %v", err) log.Debugf("cannot get JSONWebKey: %v", err)
return nil, err return nil, err
} }
} }