From 4be9089cfe5047d2baf8aea80b788b155846c5c5 Mon Sep 17 00:00:00 2001 From: Cam Date: Wed, 13 Sep 2023 10:37:38 -0500 Subject: [PATCH] fixed redirect to respect intended route, added additional logging around token swapping --- endpoints/publicProxy/github.go | 14 ++++++++++---- endpoints/publicProxy/google.go | 14 ++++++++++---- endpoints/publicProxy/http.go | 10 ++++++---- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/endpoints/publicProxy/github.go b/endpoints/publicProxy/github.go index fbc2dc55..2113d7f3 100644 --- a/endpoints/publicProxy/github.go +++ b/endpoints/publicProxy/github.go @@ -67,7 +67,7 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { type IntermediateJWT struct { State string `json:"state"` - Share string `json:"share"` + Host string `json:"host"` AuthorizationCheckInterval string `json:"authorizationCheckInterval"` jwt.RegisteredClaims } @@ -81,11 +81,15 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + host, err := url.QueryUnescape(r.URL.Query().Get("targethost")) + if err != nil { + logrus.Errorf("Unable to unescape target host: %v", err) + } rp.AuthURLHandler(func() string { id := uuid.New().String() t := jwt.NewWithClaims(jwt.SigningMethodHS256, IntermediateJWT{ id, - r.URL.Query().Get("share"), + host, r.URL.Query().Get("checkInterval"), jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), @@ -120,19 +124,21 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokens.AccessToken)) resp, err := http.DefaultClient.Do(req) if err != nil { - logrus.Error("Get: " + err.Error() + "\n") + logrus.Error("Error getting user info from github: " + err.Error() + "\n") http.Error(w, err.Error(), http.StatusInternalServerError) return } defer resp.Body.Close() response, err := io.ReadAll(resp.Body) if err != nil { + logrus.Errorf("Error reading response body: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } rDat := []githubUserResp{} err = json.Unmarshal(response, &rDat) if err != nil { + logrus.Errorf("Error unmarshalling google oauth response: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -162,7 +168,7 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error { } SetZrokCookie(w, primaryEmail, tokens.AccessToken, "github", authCheckInterval, key) - http.Redirect(w, r, fmt.Sprintf("%s://%s.%s:8080", scheme, token.Claims.(*IntermediateJWT).Share, cfg.RedirectUrl), http.StatusFound) + http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound) } http.Handle(callbackPath, rp.CodeExchangeHandler(getEmail, relyingParty)) diff --git a/endpoints/publicProxy/google.go b/endpoints/publicProxy/google.go index 66cfebc0..b3f136ab 100644 --- a/endpoints/publicProxy/google.go +++ b/endpoints/publicProxy/google.go @@ -68,7 +68,7 @@ func configureGoogleOauth(cfg *OauthConfig, tls bool) error { type IntermediateJWT struct { State string `json:"state"` - Share string `json:"share"` + Host string `json:"host"` AuthorizationCheckInterval string `json:"authorizationCheckInterval"` jwt.RegisteredClaims } @@ -79,11 +79,15 @@ func configureGoogleOauth(cfg *OauthConfig, tls bool) error { authHandlerWithQueryState := func(party rp.RelyingParty) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + host, err := url.QueryUnescape(r.URL.Query().Get("targethost")) + if err != nil { + logrus.Errorf("Unable to unescape target host: %v", err) + } rp.AuthURLHandler(func() string { id := uuid.New().String() t := jwt.NewWithClaims(jwt.SigningMethodHS256, IntermediateJWT{ id, - r.URL.Query().Get("share"), + host, r.URL.Query().Get("checkInterval"), jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), @@ -107,19 +111,21 @@ func configureGoogleOauth(cfg *OauthConfig, tls bool) error { getEmail := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty) { resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + url.QueryEscape(tokens.AccessToken)) if err != nil { - logrus.Error("Get: " + err.Error() + "\n") + logrus.Error("Error getting user info from google: " + err.Error() + "\n") http.Error(w, err.Error(), http.StatusInternalServerError) return } defer resp.Body.Close() response, err := io.ReadAll(resp.Body) if err != nil { + logrus.Errorf("Error reading response body: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } rDat := googleOauthEmailResp{} err = json.Unmarshal(response, &rDat) if err != nil { + logrus.Errorf("Error unmarshalling google oauth response: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -141,7 +147,7 @@ func configureGoogleOauth(cfg *OauthConfig, tls bool) error { } SetZrokCookie(w, rDat.Email, tokens.AccessToken, "google", authCheckInterval, key) - http.Redirect(w, r, fmt.Sprintf("%s://%s.%s:8080", scheme, token.Claims.(*IntermediateJWT).Share, cfg.RedirectUrl), http.StatusFound) + http.Redirect(w, r, fmt.Sprintf("%s://%s", scheme, token.Claims.(*IntermediateJWT).Host), http.StatusFound) } http.Handle(callbackPath, rp.CodeExchangeHandler(getEmail, relyingParty)) diff --git a/endpoints/publicProxy/http.go b/endpoints/publicProxy/http.go index 48ceaf57..240f5e98 100644 --- a/endpoints/publicProxy/http.go +++ b/endpoints/publicProxy/http.go @@ -208,10 +208,12 @@ func authHandler(handler http.Handler, realm string, pcfg *Config, ctx ziti.Cont } } + target := fmt.Sprintf("%s%s", r.Host, r.URL.Path) + cookie, err := r.Cookie("zrok-access") if err != nil { logrus.Errorf("Unable to get access cookie: %v", err) - http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?share=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), shrToken, authCheckInterval.String()), http.StatusFound) + http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) return } tkn, err := jwt.ParseWithClaims(cookie.Value, &ZrokClaims{}, func(t *jwt.Token) (interface{}, error) { @@ -222,18 +224,18 @@ func authHandler(handler http.Handler, realm string, pcfg *Config, ctx ziti.Cont }) if err != nil { logrus.Errorf("Unable to parse JWT: %v", err) - http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?share=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), shrToken, authCheckInterval.String()), http.StatusFound) + http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) return } claims := tkn.Claims.(*ZrokClaims) if claims.Provider != provider { logrus.Error("Provider mismatch. Redoing auth flow") - http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?share=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), shrToken, authCheckInterval.String()), http.StatusFound) + http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) return } if claims.AuthorizationCheckInterval != authCheckInterval { logrus.Error("Authorization check interval mismatch. Redoing auth flow") - http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?share=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), shrToken, authCheckInterval.String()), http.StatusFound) + http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) return } if validDomains, found := oauthCfg.(map[string]interface{})["email_domains"]; found {