mirror of
https://github.com/openziti/zrok.git
synced 2025-06-26 12:42:18 +02:00
draft flexible oidc configuration (#968)
This commit is contained in:
parent
1ac77fa5b9
commit
017c17156f
@ -2,11 +2,14 @@ package publicProxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
|
||||
"github.com/michaelquigley/cf"
|
||||
"github.com/openziti/zrok/endpoints"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
zhttp "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const V = 3
|
||||
@ -35,19 +38,23 @@ type OauthConfig struct {
|
||||
Providers []*OauthProviderConfig
|
||||
}
|
||||
|
||||
func (oc *OauthConfig) GetProvider(name string) *OauthProviderConfig {
|
||||
for _, provider := range oc.Providers {
|
||||
if provider.Name == name {
|
||||
return provider
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type OauthProviderConfig struct {
|
||||
Name string
|
||||
ClientId string
|
||||
ClientSecret string `cf:"+secret"`
|
||||
Scopes []string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
EmailEndpoint string
|
||||
EmailPath string
|
||||
SupportsPKCE bool
|
||||
}
|
||||
|
||||
func (p *OauthProviderConfig) GetEndpoint() oauth2.Endpoint {
|
||||
return oauth2.Endpoint{
|
||||
AuthURL: p.AuthURL,
|
||||
TokenURL: p.TokenURL,
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
@ -72,12 +79,24 @@ func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
|
||||
logrus.Info("no oauth configuration; skipping oauth handler startup")
|
||||
return nil
|
||||
}
|
||||
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
|
||||
|
||||
hash := md5.New()
|
||||
if n, err := hash.Write([]byte(cfg.Oauth.HashKey)); err != nil {
|
||||
return err
|
||||
} else if n != len(cfg.Oauth.HashKey) {
|
||||
return errors.New("short hash")
|
||||
}
|
||||
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
|
||||
return err
|
||||
key := hash.Sum(nil)
|
||||
|
||||
for _, providerCfg := range cfg.Oauth.Providers {
|
||||
provider, err := configureOIDCProvider(cfg.Oauth, providerCfg, tls)
|
||||
if err != nil {
|
||||
logrus.Warnf("failed to configure provider %s: %v", providerCfg.Name, err)
|
||||
continue
|
||||
}
|
||||
provider.setupHandlers(cfg.Oauth, key, tls)
|
||||
}
|
||||
|
||||
zhttp.StartServer(ctx, cfg.Oauth.BindAddress)
|
||||
return nil
|
||||
}
|
||||
|
@ -1,154 +0,0 @@
|
||||
package publicProxy
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
zhttp "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
googleOauth "golang.org/x/oauth2/google"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
func configureGoogleOauth(cfg *OauthConfig, tls bool) error {
|
||||
scheme := "http"
|
||||
if tls {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
providerCfg := cfg.GetProvider("google")
|
||||
if providerCfg == nil {
|
||||
logrus.Info("unable to find provider config for google. Skipping.")
|
||||
return nil
|
||||
}
|
||||
|
||||
clientID := providerCfg.ClientId
|
||||
rpConfig := &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: providerCfg.ClientSecret,
|
||||
RedirectURL: fmt.Sprintf("%v/google/oauth", cfg.RedirectUrl),
|
||||
Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"},
|
||||
Endpoint: googleOauth.Endpoint,
|
||||
}
|
||||
|
||||
hash := md5.New()
|
||||
n, err := hash.Write([]byte(cfg.HashKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != len(cfg.HashKey) {
|
||||
return errors.New("short hash")
|
||||
}
|
||||
key := hash.Sum(nil)
|
||||
|
||||
cookieHandler := zhttp.NewCookieHandler(key, key, zhttp.WithUnsecure(), zhttp.WithDomain(cfg.CookieDomain))
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithCookieHandler(cookieHandler),
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
|
||||
rp.WithPKCE(cookieHandler),
|
||||
}
|
||||
|
||||
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
type IntermediateJWT struct {
|
||||
State string `json:"state"`
|
||||
Host string `json:"host"`
|
||||
AuthorizationCheckInterval string `json:"authorizationCheckInterval"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type googleOauthEmailResp struct {
|
||||
Email string
|
||||
}
|
||||
|
||||
authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
host, err := url.QueryUnescape(r.URL.Query().Get("targethost"))
|
||||
if err != nil {
|
||||
logrus.Errorf("unable to unescape target host: %v", err)
|
||||
}
|
||||
rp.AuthURLHandler(func() string {
|
||||
id := uuid.New().String()
|
||||
t := jwt.NewWithClaims(jwt.SigningMethodHS256, IntermediateJWT{
|
||||
id,
|
||||
host,
|
||||
r.URL.Query().Get("checkInterval"),
|
||||
jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "zrok",
|
||||
Subject: "intermediate_token",
|
||||
ID: id,
|
||||
},
|
||||
})
|
||||
s, err := t.SignedString(key)
|
||||
if err != nil {
|
||||
logrus.Errorf("unable to sign intermediate JWT: %v", err)
|
||||
}
|
||||
return s
|
||||
}, party, rp.WithURLParam("access_type", "offline"))(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
http.Handle("/google/login", authHandlerWithQueryState(relyingParty))
|
||||
getEmail := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) {
|
||||
resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + url.QueryEscape(tokens.AccessToken))
|
||||
if err != nil {
|
||||
logrus.Errorf("error getting user info from google: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("error reading response body: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
logrus.Infof("response from google userinfo endpoint: %s", string(response))
|
||||
rDat := googleOauthEmailResp{}
|
||||
err = json.Unmarshal(response, &rDat)
|
||||
if err != nil {
|
||||
logrus.Errorf("error unmarshalling google oauth response: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(state, &IntermediateJWT{}, func(t *jwt.Token) (interface{}, error) {
|
||||
return key, nil
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("After intermediate token parse: %v", err.Error()), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
authCheckInterval := 3 * time.Hour
|
||||
i, err := time.ParseDuration(token.Claims.(*IntermediateJWT).AuthorizationCheckInterval)
|
||||
if err != nil {
|
||||
logrus.Errorf("unable to parse authorization check interval: %v. Defaulting to 3 hours", err)
|
||||
} else {
|
||||
authCheckInterval = i
|
||||
}
|
||||
SetZrokCookie(w, cfg.CookieDomain, rDat.Email, tokens.AccessToken, "google", authCheckInterval, key, token.Claims.(*IntermediateJWT).Host)
|
||||
http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound)
|
||||
}
|
||||
|
||||
http.Handle("/google/oauth", rp.CodeExchangeHandler(getEmail, relyingParty))
|
||||
return nil
|
||||
}
|
@ -5,6 +5,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -12,54 +18,14 @@ import (
|
||||
zhttp "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
githubOAuth "golang.org/x/oauth2/github"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
func configureGithubOauth(cfg *OauthConfig, tls bool) error {
|
||||
scheme := "http"
|
||||
if tls {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
providerCfg := cfg.GetProvider("github")
|
||||
if providerCfg == nil {
|
||||
logrus.Info("unable to find provider config for github; skipping")
|
||||
return nil
|
||||
}
|
||||
clientID := providerCfg.ClientId
|
||||
rpConfig := &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: providerCfg.ClientSecret,
|
||||
RedirectURL: fmt.Sprintf("%v/github/oauth", cfg.RedirectUrl),
|
||||
Scopes: []string{"user:email"},
|
||||
Endpoint: githubOAuth.Endpoint,
|
||||
}
|
||||
|
||||
hash := md5.New()
|
||||
n, err := hash.Write([]byte(cfg.HashKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != len(cfg.HashKey) {
|
||||
return errors.New("short hash")
|
||||
}
|
||||
key := hash.Sum(nil)
|
||||
|
||||
cookieHandler := zhttp.NewCookieHandler(key, key, zhttp.WithUnsecure(), zhttp.WithDomain(cfg.CookieDomain))
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithCookieHandler(cookieHandler),
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
|
||||
//rp.WithPKCE(cookieHandler), //Github currently doesn't support pkce. Update when that changes.
|
||||
}
|
||||
|
||||
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
|
||||
if err != nil {
|
||||
return err
|
||||
type OIDCProvider struct {
|
||||
name string
|
||||
config *oauth2.Config
|
||||
relyingParty rp.RelyingParty
|
||||
emailEndpoint string
|
||||
emailPath string
|
||||
}
|
||||
|
||||
type IntermediateJWT struct {
|
||||
@ -69,11 +35,56 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type githubUserResp struct {
|
||||
Email string
|
||||
Primary bool
|
||||
Verified bool
|
||||
Visibility string
|
||||
func configureOIDCProvider(cfg *OauthConfig, providerCfg *OauthProviderConfig, tls bool) (*OIDCProvider, error) {
|
||||
if providerCfg == nil {
|
||||
return nil, errors.New("provider configuration is required")
|
||||
}
|
||||
|
||||
rpConfig := &oauth2.Config{
|
||||
ClientID: providerCfg.ClientId,
|
||||
ClientSecret: providerCfg.ClientSecret,
|
||||
RedirectURL: fmt.Sprintf("%v/%s/oauth", cfg.RedirectUrl, providerCfg.Name),
|
||||
Scopes: providerCfg.Scopes,
|
||||
Endpoint: providerCfg.GetEndpoint(),
|
||||
}
|
||||
|
||||
hash := md5.New()
|
||||
if n, err := hash.Write([]byte(cfg.HashKey)); err != nil {
|
||||
return nil, err
|
||||
} else if n != len(cfg.HashKey) {
|
||||
return nil, errors.New("short hash")
|
||||
}
|
||||
key := hash.Sum(nil)
|
||||
|
||||
cookieHandler := zhttp.NewCookieHandler(key, key, zhttp.WithUnsecure(), zhttp.WithDomain(cfg.CookieDomain))
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithCookieHandler(cookieHandler),
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
|
||||
}
|
||||
|
||||
if providerCfg.SupportsPKCE {
|
||||
options = append(options, rp.WithPKCE(cookieHandler))
|
||||
}
|
||||
|
||||
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &OIDCProvider{
|
||||
name: providerCfg.Name,
|
||||
config: rpConfig,
|
||||
relyingParty: relyingParty,
|
||||
emailEndpoint: providerCfg.EmailEndpoint,
|
||||
emailPath: providerCfg.EmailPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) setupHandlers(cfg *OauthConfig, key []byte, tls bool) {
|
||||
scheme := "http"
|
||||
if tls {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc {
|
||||
@ -106,9 +117,8 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
http.Handle("/github/login", authHandlerWithQueryState(relyingParty))
|
||||
getEmail := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) {
|
||||
parsedUrl, err := url.Parse("https://api.github.com/user/emails")
|
||||
parsedUrl, err := url.Parse(p.emailEndpoint)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@ -121,35 +131,26 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokens.AccessToken))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
logrus.Errorf("error getting user info from github: %v", err)
|
||||
logrus.Errorf("error getting user info from %s: %v", p.name, err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("error reading response body: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var rDat []githubUserResp
|
||||
err = json.Unmarshal(response, &rDat)
|
||||
|
||||
email, err := p.extractEmail(response)
|
||||
if err != nil {
|
||||
logrus.Errorf("error unmarshalling google oauth response: %v", err)
|
||||
logrus.Errorf("error extracting email: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
primaryEmail := ""
|
||||
for _, email := range rDat {
|
||||
if email.Primary {
|
||||
primaryEmail = email.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(state, &IntermediateJWT{}, func(t *jwt.Token) (interface{}, error) {
|
||||
return key, nil
|
||||
})
|
||||
@ -165,10 +166,60 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
|
||||
} else {
|
||||
authCheckInterval = i
|
||||
}
|
||||
SetZrokCookie(w, cfg.CookieDomain, primaryEmail, tokens.AccessToken, "github", authCheckInterval, key, token.Claims.(*IntermediateJWT).Host)
|
||||
|
||||
SetZrokCookie(w, cfg.CookieDomain, email, tokens.AccessToken, p.name, authCheckInterval, key, token.Claims.(*IntermediateJWT).Host)
|
||||
http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound)
|
||||
}
|
||||
|
||||
http.Handle("/github/oauth", rp.CodeExchangeHandler(getEmail, relyingParty))
|
||||
return nil
|
||||
http.Handle(fmt.Sprintf("/%s/login", p.name), authHandlerWithQueryState(p.relyingParty))
|
||||
http.Handle(fmt.Sprintf("/%s/oauth", p.name), rp.CodeExchangeHandler(getEmail, p.relyingParty))
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) extractEmail(response []byte) (string, error) {
|
||||
var data interface{}
|
||||
if err := json.Unmarshal(response, &data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// handle array response (like GitHub's email endpoint)
|
||||
if arr, ok := data.([]interface{}); ok {
|
||||
for _, item := range arr {
|
||||
if email, found := p.findEmailInMap(item.(map[string]interface{})); found {
|
||||
return email, nil
|
||||
}
|
||||
}
|
||||
return "", errors.New("no primary email found in array response")
|
||||
}
|
||||
|
||||
// handle single object response (like Google's userinfo endpoint)
|
||||
if obj, ok := data.(map[string]interface{}); ok {
|
||||
if email, found := p.findEmailInMap(obj); found {
|
||||
return email, nil
|
||||
}
|
||||
return "", errors.New("no email found in object response")
|
||||
}
|
||||
|
||||
return "", errors.New("unexpected response format")
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) findEmailInMap(obj map[string]interface{}) (string, bool) {
|
||||
paths := strings.Split(p.emailPath, ".")
|
||||
current := obj
|
||||
|
||||
for i, path := range paths {
|
||||
if i == len(paths)-1 {
|
||||
if email, ok := current[path].(string); ok {
|
||||
return email, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
if next, ok := current[path].(map[string]interface{}); ok {
|
||||
current = next
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user