lint and cleanup; oauthLoginRequired (#404)

This commit is contained in:
Michael Quigley 2023-09-26 13:42:41 -04:00
parent b63b1fc145
commit adbe4e78c0
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
2 changed files with 42 additions and 31 deletions

View File

@ -1,8 +1,13 @@
package publicProxy
import (
"context"
"fmt"
"github.com/michaelquigley/cf"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
zhttp "github.com/zitadel/oidc/v2/pkg/http"
"strings"
)
type Config struct {
@ -16,10 +21,10 @@ type OauthConfig struct {
Port int
RedirectUrl string
HashKeyRaw string
Providers []*OauthProviderSecrets
Providers []*OauthProviderConfig
}
func (oc *OauthConfig) GetProvider(name string) *OauthProviderSecrets {
func (oc *OauthConfig) GetProvider(name string) *OauthProviderConfig {
for _, provider := range oc.Providers {
if provider.Name == name {
return provider
@ -28,7 +33,7 @@ func (oc *OauthConfig) GetProvider(name string) *OauthProviderSecrets {
return nil
}
type OauthProviderSecrets struct {
type OauthProviderConfig struct {
Name string
ClientId string
ClientSecret string
@ -47,3 +52,18 @@ func (c *Config) Load(path string) error {
}
return nil
}
func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
if cfg.Oauth == nil {
logrus.Info("no oauth configuration; skipping oauth handler startup")
return nil
}
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
return err
}
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
return err
}
zhttp.StartServer(ctx, fmt.Sprintf("%s:%d", strings.Split(cfg.Address, ":")[0], cfg.Oauth.Port))
return nil
}

View File

@ -14,7 +14,6 @@ import (
"github.com/openziti/zrok/util"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
zhttp "github.com/zitadel/oidc/v2/pkg/http"
"net"
"net/http"
"net/http/httputil"
@ -192,6 +191,8 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
handler.ServeHTTP(w, r)
case string(sdk.Oauth):
logrus.Debugf("auth scheme oauth '%v", shrToken)
if oauthCfg, found := cfg["oauth"]; found {
if provider, found := oauthCfg.(map[string]interface{})["provider"]; found {
var authCheckInterval time.Duration
@ -212,35 +213,35 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
cookie, err := r.Cookie("zrok-access")
if err != nil {
logrus.Errorf("Unable to get access cookie: %v", err)
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
logrus.Errorf("unable to get 'zrok-access' cookie: %v", err)
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return
}
tkn, err := jwt.ParseWithClaims(cookie.Value, &ZrokClaims{}, func(t *jwt.Token) (interface{}, error) {
if pcfg.Oauth == nil {
return nil, fmt.Errorf("missing oauth configuration for access point. Unable to parse jwt")
return nil, fmt.Errorf("missing oauth configuration for access point; unable to parse jwt")
}
return []byte(pcfg.Oauth.HashKeyRaw), nil
})
if err != nil {
logrus.Errorf("Unable to parse JWT: %v", err)
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
logrus.Errorf("unable to parse jwt: %v", err)
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return
}
claims := tkn.Claims.(*ZrokClaims)
if claims.Provider != provider {
logrus.Error("Provider mismatch. Redoing auth flow")
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
logrus.Error("provider mismatch; restarting auth flow")
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return
}
if claims.AuthorizationCheckInterval != authCheckInterval {
logrus.Error("Authorization check interval mismatch. Redoing auth flow")
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
logrus.Error("authorization check interval mismatch; restarting auth flow")
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return
}
if validDomains, found := oauthCfg.(map[string]interface{})["email_domains"]; found {
if castedDomains, ok := validDomains.([]interface{}); !ok {
logrus.Error("Invalid format for valid email domains")
logrus.Error("invalid email domain format")
return
} else {
if len(castedDomains) > 0 {
@ -252,7 +253,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
}
}
if !found {
logrus.Warnf("Email not a valid domain")
logrus.Warnf("invalid email domain")
unauthorizedUi.WriteUnauthorized(w)
return
}
@ -261,6 +262,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
}
handler.ServeHTTP(w, r)
return
} else {
logrus.Warnf("%v -> no provider for '%v'", r.RemoteAddr, provider)
notFoundUi.WriteNotFound(w)
@ -293,21 +295,6 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
}
}
func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
if cfg.Oauth == nil {
logrus.Info("No oauth config for access point. Skipping spin up.")
return nil
}
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
return err
}
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
return err
}
zhttp.StartServer(ctx, fmt.Sprintf("%s:%d", strings.Split(cfg.Address, ":")[0], cfg.Oauth.Port))
return nil
}
type ZrokClaims struct {
Email string `json:"email"`
AccessToken string `json:"accessToken"`
@ -325,7 +312,7 @@ func SetZrokCookie(w http.ResponseWriter, domain, email, accessToken, provider s
})
sTkn, err := tkn.SignedString(key)
if err != nil {
http.Error(w, fmt.Sprintf("After signing cookie token: %v", err.Error()), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("after signing cookie token: %v", err.Error()), http.StatusInternalServerError)
return
}
@ -346,6 +333,10 @@ func basicAuthRequired(w http.ResponseWriter, realm string) {
w.Write([]byte("No Authorization\n"))
}
func oauthLoginRequired(w http.ResponseWriter, r *http.Request, shrToken string, pcfg *Config, provider, target string, authCheckInterval time.Duration) {
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider, url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
}
func resolveService(hostMatch string, host string) string {
logrus.Debugf("host = '%v'", host)
if hostMatch == "" || strings.Contains(host, hostMatch) {