From 20f0f3a0e871cbcfd16ef486e4a552a3460e8b16 Mon Sep 17 00:00:00 2001 From: Michael Quigley Date: Tue, 10 Jun 2025 14:32:47 -0400 Subject: [PATCH] tested with google; need to fix host/claims mismatch reauth (#968) --- endpoints/publicProxy/http.go | 27 +++++++++++++++------ endpoints/publicProxy/oidc.go | 44 +++++++++++++++++++++++++++++++---- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/endpoints/publicProxy/http.go b/endpoints/publicProxy/http.go index 62a11057..5b0c8e9c 100644 --- a/endpoints/publicProxy/http.go +++ b/endpoints/publicProxy/http.go @@ -4,6 +4,13 @@ import ( "context" "crypto/md5" "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + "github.com/gobwas/glob" "github.com/golang-jwt/jwt/v5" "github.com/openziti/sdk-golang/ziti" @@ -17,12 +24,6 @@ import ( "github.com/openziti/zrok/util" "github.com/pkg/errors" "github.com/sirupsen/logrus" - "net" - "net/http" - "net/http/httputil" - "net/url" - "strings" - "time" ) type HttpFrontend struct { @@ -435,7 +436,19 @@ func basicAuthRequired(w http.ResponseWriter, realm string) { } func oauthLoginRequired(w http.ResponseWriter, r *http.Request, cfg *OauthConfig, provider, target string, authCheckInterval time.Duration) { - http.Redirect(w, r, fmt.Sprintf("%s/%s/login?targethost=%s&checkInterval=%s", cfg.RedirectUrl, provider, url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) + targetHost := r.Host + if targetHost == "" { + logrus.Error("request host is empty") + http.Error(w, "Invalid request host", http.StatusBadRequest) + return + } + + http.Redirect(w, r, fmt.Sprintf("%s/oauth/%s/login?targethost=%s&checkInterval=%s", + cfg.RedirectUrl, + provider, + url.QueryEscape(targetHost), + authCheckInterval.String()), + http.StatusFound) } func resolveService(hostMatch string, host string) string { diff --git a/endpoints/publicProxy/oidc.go b/endpoints/publicProxy/oidc.go index 88769084..21c1f58e 100644 --- a/endpoints/publicProxy/oidc.go +++ b/endpoints/publicProxy/oidc.go @@ -36,6 +36,8 @@ type IntermediateJWT struct { } func configureOIDCProvider(cfg *OauthConfig, providerCfg *OauthProviderConfig, tls bool) (*OIDCProvider, error) { + logrus.Infof("configuring oidc provider: %v", providerCfg.Name) + if providerCfg == nil { return nil, errors.New("provider configuration is required") } @@ -43,7 +45,7 @@ func configureOIDCProvider(cfg *OauthConfig, providerCfg *OauthProviderConfig, t rpConfig := &oauth2.Config{ ClientID: providerCfg.ClientId, ClientSecret: providerCfg.ClientSecret, - RedirectURL: fmt.Sprintf("%v/%s/oauth", cfg.RedirectUrl, providerCfg.Name), + RedirectURL: fmt.Sprintf("%v/oauth/%s", cfg.RedirectUrl, providerCfg.Name), Scopes: providerCfg.Scopes, Endpoint: providerCfg.GetEndpoint(), } @@ -92,7 +94,36 @@ func (p *OIDCProvider) setupHandlers(cfg *OauthConfig, key []byte, tls bool) { host, err := url.QueryUnescape(r.URL.Query().Get("targethost")) if err != nil { logrus.Errorf("unable to unescape target host: %v", err) + deleteZrokCookies(w, r) + http.Error(w, "Invalid target host", http.StatusBadRequest) + return } + + // Clean up the host value + host = strings.TrimSpace(host) + if host == "" { + logrus.Error("target host is empty") + deleteZrokCookies(w, r) + http.Error(w, "Empty target host", http.StatusBadRequest) + return + } + + // Remove any scheme, path, or query parameters + if strings.Contains(host, "://") { + if parsedURL, err := url.Parse(host); err == nil && parsedURL.Host != "" { + host = parsedURL.Host + } + } + // If there's still a path component, take only the first part + host = strings.Split(host, "/")[0] + + if host == "" { + logrus.Error("failed to extract valid host") + deleteZrokCookies(w, r) + http.Error(w, "Invalid target host", http.StatusBadRequest) + return + } + rp.AuthURLHandler(func() string { id := uuid.New().String() t := jwt.NewWithClaims(jwt.SigningMethodHS256, IntermediateJWT{ @@ -167,12 +198,15 @@ func (p *OIDCProvider) setupHandlers(cfg *OauthConfig, key []byte, tls bool) { authCheckInterval = i } - SetZrokCookie(w, cfg.CookieDomain, email, tokens.AccessToken, p.name, authCheckInterval, key, token.Claims.(*IntermediateJWT).Host) - http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound) + targetHost := token.Claims.(*IntermediateJWT).Host + logrus.Infof("setting cookie and redirecting to host: %s", targetHost) + + SetZrokCookie(w, cfg.CookieDomain, email, tokens.AccessToken, p.name, authCheckInterval, key, targetHost) + http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, targetHost), http.StatusFound) } - http.Handle(fmt.Sprintf("/%s/login", p.name), authHandlerWithQueryState(p.relyingParty)) - http.Handle(fmt.Sprintf("/%s/oauth", p.name), rp.CodeExchangeHandler(getEmail, p.relyingParty)) + http.Handle(fmt.Sprintf("/oauth/%s/login", p.name), authHandlerWithQueryState(p.relyingParty)) + http.Handle(fmt.Sprintf("/oauth/%s", p.name), rp.CodeExchangeHandler(getEmail, p.relyingParty)) } func (p *OIDCProvider) extractEmail(response []byte) (string, error) {