diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 176b4d9f..ca9cfc51 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -51,7 +51,7 @@ func NewAuthVerifier(cfg v1.AuthServerConfig) (authVerifier Verifier) { authVerifier = NewTokenAuth(cfg.AdditionalScopes, cfg.Token) case v1.AuthMethodOIDC: tokenVerifier := NewTokenVerifier(cfg.OIDC) - authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedHostedDomains) + authVerifier = NewOidcAuthVerifier(cfg.AdditionalScopes, tokenVerifier, cfg.OIDC.AllowedClaims) } return authVerifier } diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go index ed5bb543..6b926d29 100644 --- a/pkg/auth/oidc.go +++ b/pkg/auth/oidc.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" "slices" + "strconv" "strings" "github.com/coreos/go-oidc/v3/oidc" @@ -111,8 +112,8 @@ type OidcAuthConsumer struct { verifier TokenVerifier subjectsFromLogin []string - // allowedHostedDomains specifies a list of allowed hosted domains for the "hd" claim in the token. - allowedHostedDomains []string + // allowedClaims specifies a map of allowed claims for the OIDC token. + allowedClaims map[string]string } func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier { @@ -129,19 +130,19 @@ func NewTokenVerifier(cfg v1.AuthOIDCServerConfig) TokenVerifier { return provider.Verifier(&verifierConf) } -func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedHostedDomains []string) *OidcAuthConsumer { +func NewOidcAuthVerifier(additionalAuthScopes []v1.AuthScope, verifier TokenVerifier, allowedClaims map[string]string) *OidcAuthConsumer { return &OidcAuthConsumer{ additionalAuthScopes: additionalAuthScopes, verifier: verifier, subjectsFromLogin: []string{}, - allowedHostedDomains: allowedHostedDomains, + allowedClaims: allowedClaims, } } func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) { - // Verify hosted domain (hd claim). - if len(auth.allowedHostedDomains) > 0 { - // Decode token without verifying signature to retrieved 'hd' claim. + // Verify allowed claims if configured. + if len(auth.allowedClaims) > 0 { + // Decode token without verifying signature. parts := strings.Split(loginMsg.PrivilegeKey, ".") if len(parts) != 3 { return fmt.Errorf("invalid OIDC token format") @@ -157,24 +158,32 @@ func (auth *OidcAuthConsumer) VerifyLogin(loginMsg *msg.Login) (err error) { return fmt.Errorf("invalid OIDC token: failed to unmarshal payload: %v", err) } - hd, ok := claims["hd"].(string) - if !ok { - return fmt.Errorf("OIDC token missing required 'hd' claim") - } - - found := false - for _, domain := range auth.allowedHostedDomains { - if hd == domain { - found = true - break + // Iterate over allowed claims and attempt to verify. + for claimName, expectedValue := range auth.allowedClaims { + claimValue, ok := claims[claimName] + if !ok { + return fmt.Errorf("OIDC token missing required claim: %s", claimName) + } + + if strClaimValue, ok := claimValue.(string); ok { + if strClaimValue != expectedValue { + return fmt.Errorf("OIDC token claim '%s' value [%s] does not match expected value [%s]", claimName, strClaimValue, expectedValue) + } + } else if intClaimValue, ok := claimValue.(int); ok { + expectedIntValue, err := strconv.Atoi(expectedValue) + if err != nil { + return fmt.Errorf("OIDC token claim '%s' is number, expected value [%s] not parseable", claimName, expectedValue) + } + if intClaimValue != expectedIntValue { + return fmt.Errorf("OIDC token claim '%s' value [%d] does not match expected value [%d]", claimName, intClaimValue, expectedIntValue) + } + } else { + return fmt.Errorf("claim %s is of unsupported type", claimName) } - } - if !found { - return fmt.Errorf("OIDC token 'hd' claim [%s] is not in allowed list", hd) } } - // If hd check passes, proceed with standard verification. + // If claim verification passes, proceed with standard verification. token, err := auth.verifier.Verify(context.Background(), loginMsg.PrivilegeKey) if err != nil { return fmt.Errorf("invalid OIDC token in login: %v", err) diff --git a/pkg/auth/oidc_test.go b/pkg/auth/oidc_test.go index 66ff7fb9..dff2b5bc 100644 --- a/pkg/auth/oidc_test.go +++ b/pkg/auth/oidc_test.go @@ -23,7 +23,7 @@ func (m *mockTokenVerifier) Verify(ctx context.Context, subject string) (*oidc.I func TestPingWithEmptySubjectFromLoginFails(t *testing.T) { r := require.New(t) - consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{}) + consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{}) err := consumer.VerifyPing(&msg.Ping{ PrivilegeKey: "ping-without-login", Timestamp: time.Now().UnixMilli(), @@ -34,7 +34,7 @@ func TestPingWithEmptySubjectFromLoginFails(t *testing.T) { func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) { r := require.New(t) - consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{}) + consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{}) err := consumer.VerifyLogin(&msg.Login{ PrivilegeKey: "ping-after-login", }) @@ -49,7 +49,7 @@ func TestPingAfterLoginWithNewSubjectSucceeds(t *testing.T) { func TestPingAfterLoginWithDifferentSubjectFails(t *testing.T) { r := require.New(t) - consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, []string{}) + consumer := auth.NewOidcAuthVerifier([]v1.AuthScope{v1.AuthScopeHeartBeats}, &mockTokenVerifier{}, map[string]string{}) err := consumer.VerifyLogin(&msg.Login{ PrivilegeKey: "login-with-first-subject", }) diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go index 93281cf5..e233d110 100644 --- a/pkg/config/v1/server.go +++ b/pkg/config/v1/server.go @@ -147,15 +147,14 @@ type AuthOIDCServerConfig struct { // SkipIssuerCheck specifies whether to skip checking if the OIDC token's // issuer claim matches the issuer specified in OidcIssuer. SkipIssuerCheck bool `json:"skipIssuerCheck,omitempty"` - // AllowedHostedDomains specifies a list of allowed hosted domains for the - // "hd" claim in the token. - AllowedHostedDomains []string `json:"allowedHostedDomains,omitempty"` + // AllowedClaims specifies a map of allowed claims for the OIDC token. + AllowedClaims map[string]string `json:"allowedClaims,omitempty"` } func (c *AuthOIDCServerConfig) Complete() { - // Ensure AllowedHostedDomains is an empty slice and not nil - if c.AllowedHostedDomains == nil { - c.AllowedHostedDomains = []string{} + // Ensure AllowedClaims is at least an empty map and not nil + if c.AllowedClaims == nil { + c.AllowedClaims = map[string]string{} } }