lint and cleanup; oauthLoginRequired (#404)

This commit is contained in:
Michael Quigley 2023-09-26 13:42:41 -04:00
parent b63b1fc145
commit adbe4e78c0
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
2 changed files with 42 additions and 31 deletions

View File

@ -1,8 +1,13 @@
package publicProxy package publicProxy
import ( import (
"context"
"fmt"
"github.com/michaelquigley/cf" "github.com/michaelquigley/cf"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
zhttp "github.com/zitadel/oidc/v2/pkg/http"
"strings"
) )
type Config struct { type Config struct {
@ -16,10 +21,10 @@ type OauthConfig struct {
Port int Port int
RedirectUrl string RedirectUrl string
HashKeyRaw string HashKeyRaw string
Providers []*OauthProviderSecrets Providers []*OauthProviderConfig
} }
func (oc *OauthConfig) GetProvider(name string) *OauthProviderSecrets { func (oc *OauthConfig) GetProvider(name string) *OauthProviderConfig {
for _, provider := range oc.Providers { for _, provider := range oc.Providers {
if provider.Name == name { if provider.Name == name {
return provider return provider
@ -28,7 +33,7 @@ func (oc *OauthConfig) GetProvider(name string) *OauthProviderSecrets {
return nil return nil
} }
type OauthProviderSecrets struct { type OauthProviderConfig struct {
Name string Name string
ClientId string ClientId string
ClientSecret string ClientSecret string
@ -47,3 +52,18 @@ func (c *Config) Load(path string) error {
} }
return nil return nil
} }
func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
if cfg.Oauth == nil {
logrus.Info("no oauth configuration; skipping oauth handler startup")
return nil
}
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
return err
}
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
return err
}
zhttp.StartServer(ctx, fmt.Sprintf("%s:%d", strings.Split(cfg.Address, ":")[0], cfg.Oauth.Port))
return nil
}

View File

@ -14,7 +14,6 @@ import (
"github.com/openziti/zrok/util" "github.com/openziti/zrok/util"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
zhttp "github.com/zitadel/oidc/v2/pkg/http"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@ -192,6 +191,8 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
case string(sdk.Oauth): case string(sdk.Oauth):
logrus.Debugf("auth scheme oauth '%v", shrToken)
if oauthCfg, found := cfg["oauth"]; found { if oauthCfg, found := cfg["oauth"]; found {
if provider, found := oauthCfg.(map[string]interface{})["provider"]; found { if provider, found := oauthCfg.(map[string]interface{})["provider"]; found {
var authCheckInterval time.Duration var authCheckInterval time.Duration
@ -212,35 +213,35 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
cookie, err := r.Cookie("zrok-access") cookie, err := r.Cookie("zrok-access")
if err != nil { if err != nil {
logrus.Errorf("Unable to get access cookie: %v", err) logrus.Errorf("unable to get 'zrok-access' cookie: %v", err)
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return return
} }
tkn, err := jwt.ParseWithClaims(cookie.Value, &ZrokClaims{}, func(t *jwt.Token) (interface{}, error) { tkn, err := jwt.ParseWithClaims(cookie.Value, &ZrokClaims{}, func(t *jwt.Token) (interface{}, error) {
if pcfg.Oauth == nil { if pcfg.Oauth == nil {
return nil, fmt.Errorf("missing oauth configuration for access point. Unable to parse jwt") return nil, fmt.Errorf("missing oauth configuration for access point; unable to parse jwt")
} }
return []byte(pcfg.Oauth.HashKeyRaw), nil return []byte(pcfg.Oauth.HashKeyRaw), nil
}) })
if err != nil { if err != nil {
logrus.Errorf("Unable to parse JWT: %v", err) logrus.Errorf("unable to parse jwt: %v", err)
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return return
} }
claims := tkn.Claims.(*ZrokClaims) claims := tkn.Claims.(*ZrokClaims)
if claims.Provider != provider { if claims.Provider != provider {
logrus.Error("Provider mismatch. Redoing auth flow") logrus.Error("provider mismatch; restarting auth flow")
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return return
} }
if claims.AuthorizationCheckInterval != authCheckInterval { if claims.AuthorizationCheckInterval != authCheckInterval {
logrus.Error("Authorization check interval mismatch. Redoing auth flow") logrus.Error("authorization check interval mismatch; restarting auth flow")
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound) oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
return return
} }
if validDomains, found := oauthCfg.(map[string]interface{})["email_domains"]; found { if validDomains, found := oauthCfg.(map[string]interface{})["email_domains"]; found {
if castedDomains, ok := validDomains.([]interface{}); !ok { if castedDomains, ok := validDomains.([]interface{}); !ok {
logrus.Error("Invalid format for valid email domains") logrus.Error("invalid email domain format")
return return
} else { } else {
if len(castedDomains) > 0 { if len(castedDomains) > 0 {
@ -252,7 +253,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
} }
} }
if !found { if !found {
logrus.Warnf("Email not a valid domain") logrus.Warnf("invalid email domain")
unauthorizedUi.WriteUnauthorized(w) unauthorizedUi.WriteUnauthorized(w)
return return
} }
@ -261,6 +262,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
} }
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
return return
} else { } else {
logrus.Warnf("%v -> no provider for '%v'", r.RemoteAddr, provider) logrus.Warnf("%v -> no provider for '%v'", r.RemoteAddr, provider)
notFoundUi.WriteNotFound(w) notFoundUi.WriteNotFound(w)
@ -293,21 +295,6 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
} }
} }
func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
if cfg.Oauth == nil {
logrus.Info("No oauth config for access point. Skipping spin up.")
return nil
}
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
return err
}
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
return err
}
zhttp.StartServer(ctx, fmt.Sprintf("%s:%d", strings.Split(cfg.Address, ":")[0], cfg.Oauth.Port))
return nil
}
type ZrokClaims struct { type ZrokClaims struct {
Email string `json:"email"` Email string `json:"email"`
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken"`
@ -325,7 +312,7 @@ func SetZrokCookie(w http.ResponseWriter, domain, email, accessToken, provider s
}) })
sTkn, err := tkn.SignedString(key) sTkn, err := tkn.SignedString(key)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("After signing cookie token: %v", err.Error()), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("after signing cookie token: %v", err.Error()), http.StatusInternalServerError)
return return
} }
@ -346,6 +333,10 @@ func basicAuthRequired(w http.ResponseWriter, realm string) {
w.Write([]byte("No Authorization\n")) w.Write([]byte("No Authorization\n"))
} }
func oauthLoginRequired(w http.ResponseWriter, r *http.Request, shrToken string, pcfg *Config, provider, target string, authCheckInterval time.Duration) {
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider, url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
}
func resolveService(hostMatch string, host string) string { func resolveService(hostMatch string, host string) string {
logrus.Debugf("host = '%v'", host) logrus.Debugf("host = '%v'", host)
if hostMatch == "" || strings.Contains(host, hostMatch) { if hostMatch == "" || strings.Contains(host, hostMatch) {