draft flexible oidc configuration (#968)

This commit is contained in:
Michael Quigley 2025-06-09 16:59:33 -04:00
parent 1ac77fa5b9
commit 017c17156f
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
3 changed files with 142 additions and 226 deletions

View File

@ -2,11 +2,14 @@ package publicProxy
import (
"context"
"crypto/md5"
"github.com/michaelquigley/cf"
"github.com/openziti/zrok/endpoints"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
zhttp "github.com/zitadel/oidc/v2/pkg/http"
"golang.org/x/oauth2"
)
const V = 3
@ -35,19 +38,23 @@ type OauthConfig struct {
Providers []*OauthProviderConfig
}
func (oc *OauthConfig) GetProvider(name string) *OauthProviderConfig {
for _, provider := range oc.Providers {
if provider.Name == name {
return provider
}
}
return nil
}
type OauthProviderConfig struct {
Name string
ClientId string
ClientSecret string `cf:"+secret"`
Scopes []string
AuthURL string
TokenURL string
EmailEndpoint string
EmailPath string
SupportsPKCE bool
}
func (p *OauthProviderConfig) GetEndpoint() oauth2.Endpoint {
return oauth2.Endpoint{
AuthURL: p.AuthURL,
TokenURL: p.TokenURL,
}
}
func DefaultConfig() *Config {
@ -72,12 +79,24 @@ func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
logrus.Info("no oauth configuration; skipping oauth handler startup")
return nil
}
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
hash := md5.New()
if n, err := hash.Write([]byte(cfg.Oauth.HashKey)); err != nil {
return err
} else if n != len(cfg.Oauth.HashKey) {
return errors.New("short hash")
}
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
return err
key := hash.Sum(nil)
for _, providerCfg := range cfg.Oauth.Providers {
provider, err := configureOIDCProvider(cfg.Oauth, providerCfg, tls)
if err != nil {
logrus.Warnf("failed to configure provider %s: %v", providerCfg.Name, err)
continue
}
provider.setupHandlers(cfg.Oauth, key, tls)
}
zhttp.StartServer(ctx, cfg.Oauth.BindAddress)
return nil
}

View File

@ -1,154 +0,0 @@
package publicProxy
import (
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/v2/pkg/client/rp"
zhttp "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
googleOauth "golang.org/x/oauth2/google"
"io"
"net/http"
"net/url"
"time"
)
func configureGoogleOauth(cfg *OauthConfig, tls bool) error {
scheme := "http"
if tls {
scheme = "https"
}
providerCfg := cfg.GetProvider("google")
if providerCfg == nil {
logrus.Info("unable to find provider config for google. Skipping.")
return nil
}
clientID := providerCfg.ClientId
rpConfig := &oauth2.Config{
ClientID: clientID,
ClientSecret: providerCfg.ClientSecret,
RedirectURL: fmt.Sprintf("%v/google/oauth", cfg.RedirectUrl),
Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"},
Endpoint: googleOauth.Endpoint,
}
hash := md5.New()
n, err := hash.Write([]byte(cfg.HashKey))
if err != nil {
return err
}
if n != len(cfg.HashKey) {
return errors.New("short hash")
}
key := hash.Sum(nil)
cookieHandler := zhttp.NewCookieHandler(key, key, zhttp.WithUnsecure(), zhttp.WithDomain(cfg.CookieDomain))
options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
rp.WithPKCE(cookieHandler),
}
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
if err != nil {
return err
}
type IntermediateJWT struct {
State string `json:"state"`
Host string `json:"host"`
AuthorizationCheckInterval string `json:"authorizationCheckInterval"`
jwt.RegisteredClaims
}
type googleOauthEmailResp struct {
Email string
}
authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
host, err := url.QueryUnescape(r.URL.Query().Get("targethost"))
if err != nil {
logrus.Errorf("unable to unescape target host: %v", err)
}
rp.AuthURLHandler(func() string {
id := uuid.New().String()
t := jwt.NewWithClaims(jwt.SigningMethodHS256, IntermediateJWT{
id,
host,
r.URL.Query().Get("checkInterval"),
jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "zrok",
Subject: "intermediate_token",
ID: id,
},
})
s, err := t.SignedString(key)
if err != nil {
logrus.Errorf("unable to sign intermediate JWT: %v", err)
}
return s
}, party, rp.WithURLParam("access_type", "offline"))(w, r)
}
}
http.Handle("/google/login", authHandlerWithQueryState(relyingParty))
getEmail := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) {
resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + url.QueryEscape(tokens.AccessToken))
if err != nil {
logrus.Errorf("error getting user info from google: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer func() {
_ = resp.Body.Close()
}()
response, err := io.ReadAll(resp.Body)
if err != nil {
logrus.Errorf("error reading response body: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
logrus.Infof("response from google userinfo endpoint: %s", string(response))
rDat := googleOauthEmailResp{}
err = json.Unmarshal(response, &rDat)
if err != nil {
logrus.Errorf("error unmarshalling google oauth response: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
token, err := jwt.ParseWithClaims(state, &IntermediateJWT{}, func(t *jwt.Token) (interface{}, error) {
return key, nil
})
if err != nil {
http.Error(w, fmt.Sprintf("After intermediate token parse: %v", err.Error()), http.StatusInternalServerError)
return
}
authCheckInterval := 3 * time.Hour
i, err := time.ParseDuration(token.Claims.(*IntermediateJWT).AuthorizationCheckInterval)
if err != nil {
logrus.Errorf("unable to parse authorization check interval: %v. Defaulting to 3 hours", err)
} else {
authCheckInterval = i
}
SetZrokCookie(w, cfg.CookieDomain, rDat.Email, tokens.AccessToken, "google", authCheckInterval, key, token.Claims.(*IntermediateJWT).Host)
http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound)
}
http.Handle("/google/oauth", rp.CodeExchangeHandler(getEmail, relyingParty))
return nil
}

View File

@ -5,6 +5,12 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
@ -12,40 +18,41 @@ import (
zhttp "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
githubOAuth "golang.org/x/oauth2/github"
"io"
"net/http"
"net/url"
"time"
)
func configureGithubOauth(cfg *OauthConfig, tls bool) error {
scheme := "http"
if tls {
scheme = "https"
type OIDCProvider struct {
name string
config *oauth2.Config
relyingParty rp.RelyingParty
emailEndpoint string
emailPath string
}
type IntermediateJWT struct {
State string `json:"state"`
Host string `json:"host"`
AuthorizationCheckInterval string `json:"authorizationCheckInterval"`
jwt.RegisteredClaims
}
func configureOIDCProvider(cfg *OauthConfig, providerCfg *OauthProviderConfig, tls bool) (*OIDCProvider, error) {
if providerCfg == nil {
return nil, errors.New("provider configuration is required")
}
providerCfg := cfg.GetProvider("github")
if providerCfg == nil {
logrus.Info("unable to find provider config for github; skipping")
return nil
}
clientID := providerCfg.ClientId
rpConfig := &oauth2.Config{
ClientID: clientID,
ClientID: providerCfg.ClientId,
ClientSecret: providerCfg.ClientSecret,
RedirectURL: fmt.Sprintf("%v/github/oauth", cfg.RedirectUrl),
Scopes: []string{"user:email"},
Endpoint: githubOAuth.Endpoint,
RedirectURL: fmt.Sprintf("%v/%s/oauth", cfg.RedirectUrl, providerCfg.Name),
Scopes: providerCfg.Scopes,
Endpoint: providerCfg.GetEndpoint(),
}
hash := md5.New()
n, err := hash.Write([]byte(cfg.HashKey))
if err != nil {
return err
}
if n != len(cfg.HashKey) {
return errors.New("short hash")
if n, err := hash.Write([]byte(cfg.HashKey)); err != nil {
return nil, err
} else if n != len(cfg.HashKey) {
return nil, errors.New("short hash")
}
key := hash.Sum(nil)
@ -54,26 +61,30 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
//rp.WithPKCE(cookieHandler), //Github currently doesn't support pkce. Update when that changes.
}
if providerCfg.SupportsPKCE {
options = append(options, rp.WithPKCE(cookieHandler))
}
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
if err != nil {
return err
return nil, err
}
type IntermediateJWT struct {
State string `json:"state"`
Host string `json:"host"`
AuthorizationCheckInterval string `json:"authorizationCheckInterval"`
jwt.RegisteredClaims
}
return &OIDCProvider{
name: providerCfg.Name,
config: rpConfig,
relyingParty: relyingParty,
emailEndpoint: providerCfg.EmailEndpoint,
emailPath: providerCfg.EmailPath,
}, nil
}
type githubUserResp struct {
Email string
Primary bool
Verified bool
Visibility string
func (p *OIDCProvider) setupHandlers(cfg *OauthConfig, key []byte, tls bool) {
scheme := "http"
if tls {
scheme = "https"
}
authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc {
@ -106,9 +117,8 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
}
}
http.Handle("/github/login", authHandlerWithQueryState(relyingParty))
getEmail := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) {
parsedUrl, err := url.Parse("https://api.github.com/user/emails")
parsedUrl, err := url.Parse(p.emailEndpoint)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -121,35 +131,26 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokens.AccessToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
logrus.Errorf("error getting user info from github: %v", err)
logrus.Errorf("error getting user info from %s: %v", p.name, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer func() {
_ = resp.Body.Close()
}()
defer func() { _ = resp.Body.Close() }()
response, err := io.ReadAll(resp.Body)
if err != nil {
logrus.Errorf("error reading response body: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var rDat []githubUserResp
err = json.Unmarshal(response, &rDat)
email, err := p.extractEmail(response)
if err != nil {
logrus.Errorf("error unmarshalling google oauth response: %v", err)
logrus.Errorf("error extracting email: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
primaryEmail := ""
for _, email := range rDat {
if email.Primary {
primaryEmail = email.Email
break
}
}
token, err := jwt.ParseWithClaims(state, &IntermediateJWT{}, func(t *jwt.Token) (interface{}, error) {
return key, nil
})
@ -165,10 +166,60 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
} else {
authCheckInterval = i
}
SetZrokCookie(w, cfg.CookieDomain, primaryEmail, tokens.AccessToken, "github", authCheckInterval, key, token.Claims.(*IntermediateJWT).Host)
SetZrokCookie(w, cfg.CookieDomain, email, tokens.AccessToken, p.name, authCheckInterval, key, token.Claims.(*IntermediateJWT).Host)
http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound)
}
http.Handle("/github/oauth", rp.CodeExchangeHandler(getEmail, relyingParty))
return nil
http.Handle(fmt.Sprintf("/%s/login", p.name), authHandlerWithQueryState(p.relyingParty))
http.Handle(fmt.Sprintf("/%s/oauth", p.name), rp.CodeExchangeHandler(getEmail, p.relyingParty))
}
func (p *OIDCProvider) extractEmail(response []byte) (string, error) {
var data interface{}
if err := json.Unmarshal(response, &data); err != nil {
return "", err
}
// handle array response (like GitHub's email endpoint)
if arr, ok := data.([]interface{}); ok {
for _, item := range arr {
if email, found := p.findEmailInMap(item.(map[string]interface{})); found {
return email, nil
}
}
return "", errors.New("no primary email found in array response")
}
// handle single object response (like Google's userinfo endpoint)
if obj, ok := data.(map[string]interface{}); ok {
if email, found := p.findEmailInMap(obj); found {
return email, nil
}
return "", errors.New("no email found in object response")
}
return "", errors.New("unexpected response format")
}
func (p *OIDCProvider) findEmailInMap(obj map[string]interface{}) (string, bool) {
paths := strings.Split(p.emailPath, ".")
current := obj
for i, path := range paths {
if i == len(paths)-1 {
if email, ok := current[path].(string); ok {
return email, true
}
return "", false
}
if next, ok := current[path].(map[string]interface{}); ok {
current = next
} else {
return "", false
}
}
return "", false
}