diff --git a/endpoints/publicProxy/config.go b/endpoints/publicProxy/config.go index c39708a7..aa5d1a23 100644 --- a/endpoints/publicProxy/config.go +++ b/endpoints/publicProxy/config.go @@ -2,11 +2,14 @@ package publicProxy import ( "context" + "crypto/md5" + "github.com/michaelquigley/cf" "github.com/openziti/zrok/endpoints" "github.com/pkg/errors" "github.com/sirupsen/logrus" zhttp "github.com/zitadel/oidc/v2/pkg/http" + "golang.org/x/oauth2" ) const V = 3 @@ -35,19 +38,23 @@ type OauthConfig struct { Providers []*OauthProviderConfig } -func (oc *OauthConfig) GetProvider(name string) *OauthProviderConfig { - for _, provider := range oc.Providers { - if provider.Name == name { - return provider - } - } - return nil +type OauthProviderConfig struct { + Name string + ClientId string + ClientSecret string `cf:"+secret"` + Scopes []string + AuthURL string + TokenURL string + EmailEndpoint string + EmailPath string + SupportsPKCE bool } -type OauthProviderConfig struct { - Name string - ClientId string - ClientSecret string `cf:"+secret"` +func (p *OauthProviderConfig) GetEndpoint() oauth2.Endpoint { + return oauth2.Endpoint{ + AuthURL: p.AuthURL, + TokenURL: p.TokenURL, + } } func DefaultConfig() *Config { @@ -72,12 +79,24 @@ func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error { logrus.Info("no oauth configuration; skipping oauth handler startup") return nil } - if err := configureGoogleOauth(cfg.Oauth, tls); err != nil { + + hash := md5.New() + if n, err := hash.Write([]byte(cfg.Oauth.HashKey)); err != nil { return err + } else if n != len(cfg.Oauth.HashKey) { + return errors.New("short hash") } - if err := configureGithubOauth(cfg.Oauth, tls); err != nil { - return err + key := hash.Sum(nil) + + for _, providerCfg := range cfg.Oauth.Providers { + provider, err := configureOIDCProvider(cfg.Oauth, providerCfg, tls) + if err != nil { + logrus.Warnf("failed to configure provider %s: %v", providerCfg.Name, err) + continue + } + provider.setupHandlers(cfg.Oauth, key, tls) } + zhttp.StartServer(ctx, cfg.Oauth.BindAddress) return nil } diff --git a/endpoints/publicProxy/google.go b/endpoints/publicProxy/google.go deleted file mode 100644 index 5b62257c..00000000 --- a/endpoints/publicProxy/google.go +++ /dev/null @@ -1,154 +0,0 @@ -package publicProxy - -import ( - "crypto/md5" - "encoding/json" - "errors" - "fmt" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rp" - zhttp "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "golang.org/x/oauth2" - googleOauth "golang.org/x/oauth2/google" - "io" - "net/http" - "net/url" - "time" -) - -func configureGoogleOauth(cfg *OauthConfig, tls bool) error { - scheme := "http" - if tls { - scheme = "https" - } - - providerCfg := cfg.GetProvider("google") - if providerCfg == nil { - logrus.Info("unable to find provider config for google. Skipping.") - return nil - } - - clientID := providerCfg.ClientId - rpConfig := &oauth2.Config{ - ClientID: clientID, - ClientSecret: providerCfg.ClientSecret, - RedirectURL: fmt.Sprintf("%v/google/oauth", cfg.RedirectUrl), - Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, - Endpoint: googleOauth.Endpoint, - } - - hash := md5.New() - n, err := hash.Write([]byte(cfg.HashKey)) - if err != nil { - return err - } - if n != len(cfg.HashKey) { - return errors.New("short hash") - } - key := hash.Sum(nil) - - cookieHandler := zhttp.NewCookieHandler(key, key, zhttp.WithUnsecure(), zhttp.WithDomain(cfg.CookieDomain)) - - options := []rp.Option{ - rp.WithCookieHandler(cookieHandler), - rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), - rp.WithPKCE(cookieHandler), - } - - relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...) - if err != nil { - return err - } - - type IntermediateJWT struct { - State string `json:"state"` - Host string `json:"host"` - AuthorizationCheckInterval string `json:"authorizationCheckInterval"` - jwt.RegisteredClaims - } - - type googleOauthEmailResp struct { - Email string - } - - authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - host, err := url.QueryUnescape(r.URL.Query().Get("targethost")) - if err != nil { - logrus.Errorf("unable to unescape target host: %v", err) - } - rp.AuthURLHandler(func() string { - id := uuid.New().String() - t := jwt.NewWithClaims(jwt.SigningMethodHS256, IntermediateJWT{ - id, - host, - r.URL.Query().Get("checkInterval"), - jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Issuer: "zrok", - Subject: "intermediate_token", - ID: id, - }, - }) - s, err := t.SignedString(key) - if err != nil { - logrus.Errorf("unable to sign intermediate JWT: %v", err) - } - return s - }, party, rp.WithURLParam("access_type", "offline"))(w, r) - } - } - - http.Handle("/google/login", authHandlerWithQueryState(relyingParty)) - getEmail := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) { - resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + url.QueryEscape(tokens.AccessToken)) - if err != nil { - logrus.Errorf("error getting user info from google: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer func() { - _ = resp.Body.Close() - }() - response, err := io.ReadAll(resp.Body) - if err != nil { - logrus.Errorf("error reading response body: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - logrus.Infof("response from google userinfo endpoint: %s", string(response)) - rDat := googleOauthEmailResp{} - err = json.Unmarshal(response, &rDat) - if err != nil { - logrus.Errorf("error unmarshalling google oauth response: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - token, err := jwt.ParseWithClaims(state, &IntermediateJWT{}, func(t *jwt.Token) (interface{}, error) { - return key, nil - }) - if err != nil { - http.Error(w, fmt.Sprintf("After intermediate token parse: %v", err.Error()), http.StatusInternalServerError) - return - } - - authCheckInterval := 3 * time.Hour - i, err := time.ParseDuration(token.Claims.(*IntermediateJWT).AuthorizationCheckInterval) - if err != nil { - logrus.Errorf("unable to parse authorization check interval: %v. Defaulting to 3 hours", err) - } else { - authCheckInterval = i - } - SetZrokCookie(w, cfg.CookieDomain, rDat.Email, tokens.AccessToken, "google", authCheckInterval, key, token.Claims.(*IntermediateJWT).Host) - http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound) - } - - http.Handle("/google/oauth", rp.CodeExchangeHandler(getEmail, relyingParty)) - return nil -} diff --git a/endpoints/publicProxy/github.go b/endpoints/publicProxy/oidc.go similarity index 52% rename from endpoints/publicProxy/github.go rename to endpoints/publicProxy/oidc.go index e8c4ef25..88769084 100644 --- a/endpoints/publicProxy/github.go +++ b/endpoints/publicProxy/oidc.go @@ -5,6 +5,12 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/sirupsen/logrus" @@ -12,40 +18,41 @@ import ( zhttp "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" "golang.org/x/oauth2" - githubOAuth "golang.org/x/oauth2/github" - "io" - "net/http" - "net/url" - "time" ) -func configureGithubOauth(cfg *OauthConfig, tls bool) error { - scheme := "http" - if tls { - scheme = "https" +type OIDCProvider struct { + name string + config *oauth2.Config + relyingParty rp.RelyingParty + emailEndpoint string + emailPath string +} + +type IntermediateJWT struct { + State string `json:"state"` + Host string `json:"host"` + AuthorizationCheckInterval string `json:"authorizationCheckInterval"` + jwt.RegisteredClaims +} + +func configureOIDCProvider(cfg *OauthConfig, providerCfg *OauthProviderConfig, tls bool) (*OIDCProvider, error) { + if providerCfg == nil { + return nil, errors.New("provider configuration is required") } - providerCfg := cfg.GetProvider("github") - if providerCfg == nil { - logrus.Info("unable to find provider config for github; skipping") - return nil - } - clientID := providerCfg.ClientId rpConfig := &oauth2.Config{ - ClientID: clientID, + ClientID: providerCfg.ClientId, ClientSecret: providerCfg.ClientSecret, - RedirectURL: fmt.Sprintf("%v/github/oauth", cfg.RedirectUrl), - Scopes: []string{"user:email"}, - Endpoint: githubOAuth.Endpoint, + RedirectURL: fmt.Sprintf("%v/%s/oauth", cfg.RedirectUrl, providerCfg.Name), + Scopes: providerCfg.Scopes, + Endpoint: providerCfg.GetEndpoint(), } hash := md5.New() - n, err := hash.Write([]byte(cfg.HashKey)) - if err != nil { - return err - } - if n != len(cfg.HashKey) { - return errors.New("short hash") + if n, err := hash.Write([]byte(cfg.HashKey)); err != nil { + return nil, err + } else if n != len(cfg.HashKey) { + return nil, errors.New("short hash") } key := hash.Sum(nil) @@ -54,26 +61,30 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { options := []rp.Option{ rp.WithCookieHandler(cookieHandler), rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), - //rp.WithPKCE(cookieHandler), //Github currently doesn't support pkce. Update when that changes. + } + + if providerCfg.SupportsPKCE { + options = append(options, rp.WithPKCE(cookieHandler)) } relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...) if err != nil { - return err + return nil, err } - type IntermediateJWT struct { - State string `json:"state"` - Host string `json:"host"` - AuthorizationCheckInterval string `json:"authorizationCheckInterval"` - jwt.RegisteredClaims - } + return &OIDCProvider{ + name: providerCfg.Name, + config: rpConfig, + relyingParty: relyingParty, + emailEndpoint: providerCfg.EmailEndpoint, + emailPath: providerCfg.EmailPath, + }, nil +} - type githubUserResp struct { - Email string - Primary bool - Verified bool - Visibility string +func (p *OIDCProvider) setupHandlers(cfg *OauthConfig, key []byte, tls bool) { + scheme := "http" + if tls { + scheme = "https" } authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc { @@ -106,9 +117,8 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { } } - http.Handle("/github/login", authHandlerWithQueryState(relyingParty)) getEmail := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) { - parsedUrl, err := url.Parse("https://api.github.com/user/emails") + parsedUrl, err := url.Parse(p.emailEndpoint) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -121,35 +131,26 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokens.AccessToken)) resp, err := http.DefaultClient.Do(req) if err != nil { - logrus.Errorf("error getting user info from github: %v", err) + logrus.Errorf("error getting user info from %s: %v", p.name, err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - defer func() { - _ = resp.Body.Close() - }() + defer func() { _ = resp.Body.Close() }() + response, err := io.ReadAll(resp.Body) if err != nil { logrus.Errorf("error reading response body: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - var rDat []githubUserResp - err = json.Unmarshal(response, &rDat) + + email, err := p.extractEmail(response) if err != nil { - logrus.Errorf("error unmarshalling google oauth response: %v", err) + logrus.Errorf("error extracting email: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - primaryEmail := "" - for _, email := range rDat { - if email.Primary { - primaryEmail = email.Email - break - } - } - token, err := jwt.ParseWithClaims(state, &IntermediateJWT{}, func(t *jwt.Token) (interface{}, error) { return key, nil }) @@ -165,10 +166,60 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { } else { authCheckInterval = i } - SetZrokCookie(w, cfg.CookieDomain, primaryEmail, tokens.AccessToken, "github", authCheckInterval, key, token.Claims.(*IntermediateJWT).Host) + + SetZrokCookie(w, cfg.CookieDomain, email, tokens.AccessToken, p.name, authCheckInterval, key, token.Claims.(*IntermediateJWT).Host) http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound) } - http.Handle("/github/oauth", rp.CodeExchangeHandler(getEmail, relyingParty)) - return nil + http.Handle(fmt.Sprintf("/%s/login", p.name), authHandlerWithQueryState(p.relyingParty)) + http.Handle(fmt.Sprintf("/%s/oauth", p.name), rp.CodeExchangeHandler(getEmail, p.relyingParty)) +} + +func (p *OIDCProvider) extractEmail(response []byte) (string, error) { + var data interface{} + if err := json.Unmarshal(response, &data); err != nil { + return "", err + } + + // handle array response (like GitHub's email endpoint) + if arr, ok := data.([]interface{}); ok { + for _, item := range arr { + if email, found := p.findEmailInMap(item.(map[string]interface{})); found { + return email, nil + } + } + return "", errors.New("no primary email found in array response") + } + + // handle single object response (like Google's userinfo endpoint) + if obj, ok := data.(map[string]interface{}); ok { + if email, found := p.findEmailInMap(obj); found { + return email, nil + } + return "", errors.New("no email found in object response") + } + + return "", errors.New("unexpected response format") +} + +func (p *OIDCProvider) findEmailInMap(obj map[string]interface{}) (string, bool) { + paths := strings.Split(p.emailPath, ".") + current := obj + + for i, path := range paths { + if i == len(paths)-1 { + if email, ok := current[path].(string); ok { + return email, true + } + return "", false + } + + if next, ok := current[path].(map[string]interface{}); ok { + current = next + } else { + return "", false + } + } + + return "", false }