mirror of
https://github.com/openziti/zrok.git
synced 2025-08-16 19:01:16 +02:00
updates to the oauth work
This commit is contained in:
@ -8,16 +8,20 @@ import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/openziti/sdk-golang/ziti"
|
||||
"github.com/openziti/zrok/endpoints"
|
||||
"github.com/openziti/zrok/endpoints/publicProxy/healthUi"
|
||||
"github.com/openziti/zrok/endpoints/publicProxy/notFoundUi"
|
||||
"github.com/openziti/zrok/endpoints/publicProxy/unauthorizedUi"
|
||||
"github.com/openziti/zrok/model"
|
||||
"github.com/openziti/zrok/util"
|
||||
"github.com/openziti/zrok/zrokdir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
zhttp "github.com/zitadel/oidc/v2/pkg/http"
|
||||
)
|
||||
|
||||
type httpFrontend struct {
|
||||
@ -49,7 +53,9 @@ func NewHTTP(cfg *Config) (*httpFrontend, error) {
|
||||
return nil, err
|
||||
}
|
||||
proxy.Transport = zTransport
|
||||
|
||||
if err := configureOauthHandlers(context.Background(), cfg, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler := authHandler(util.NewProxyHandler(proxy), "zrok", cfg, zCtx)
|
||||
return &httpFrontend{
|
||||
cfg: cfg,
|
||||
@ -125,9 +131,9 @@ func hostTargetReverseProxy(cfg *Config, ctx ziti.Context) *httputil.ReverseProx
|
||||
return &httputil.ReverseProxy{Director: director}
|
||||
}
|
||||
|
||||
func authHandler(handler http.Handler, realm string, cfg *Config, ctx ziti.Context) http.HandlerFunc {
|
||||
func authHandler(handler http.Handler, realm string, pcfg *Config, ctx ziti.Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
shrToken := resolveService(cfg.HostMatch, r.Host)
|
||||
shrToken := resolveService(pcfg.HostMatch, r.Host)
|
||||
if shrToken != "" {
|
||||
if svc, found := endpoints.GetRefreshedService(shrToken, ctx); found {
|
||||
if cfg, found := svc.Config[model.ZrokProxyConfig]; found {
|
||||
@ -183,17 +189,80 @@ func authHandler(handler http.Handler, realm string, cfg *Config, ctx ziti.Conte
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
case string(model.Oauth):
|
||||
logrus.Debugf("auth scheme oauth '%v'", shrToken)
|
||||
awsUrl := "https://oauth2/authorize" // COGNITO URL OR WHATEVER OAUTH PROVIDER URL
|
||||
responseType := "code"
|
||||
clientId := "" // PROVIDER CLIENT ID
|
||||
scope := "email"
|
||||
redirectUri := "http://localhost:18080/api/v1/oauth/authorize"
|
||||
redirectUrl := fmt.Sprintf("%s?response_type=%s&client_id=%s&redirect_uri=%s&state=STATE&scope=%s", awsUrl, responseType, clientId, redirectUri, scope)
|
||||
http.Redirect(w, r, redirectUrl, http.StatusFound)
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
|
||||
if oauthCfg, found := cfg["oauth"]; found {
|
||||
if provider, found := oauthCfg.(map[string]interface{})["provider"]; found {
|
||||
cookie, err := r.Cookie("zrok-access")
|
||||
if err != nil {
|
||||
logrus.Errorf("Unable to get access cookie: %v", err)
|
||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:28080/%s/login?share=%s", shrToken, pcfg.HostMatch, provider.(string), shrToken), http.StatusFound)
|
||||
return
|
||||
}
|
||||
tkn, err := jwt.ParseWithClaims(cookie.Value, &ZrokClaims{}, func(t *jwt.Token) (interface{}, error) {
|
||||
if pcfg.Oauth == nil {
|
||||
return nil, fmt.Errorf("Missing oauth configuration for access point. Unable to parse jwt")
|
||||
}
|
||||
return pcfg.Oauth.HashKeyRaw, nil
|
||||
})
|
||||
if err != nil {
|
||||
logrus.Errorf("Unable to parse JWT: %v", err)
|
||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:28080/%s/login?share=%s", shrToken, pcfg.HostMatch, provider.(string), shrToken), http.StatusFound)
|
||||
return
|
||||
}
|
||||
claims := tkn.Claims.(*ZrokClaims)
|
||||
if claims.Provider != provider {
|
||||
logrus.Error("Provider mismatch. Redoing auth flow")
|
||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:28080/%s/login?share=%s", shrToken, pcfg.HostMatch, provider.(string), shrToken), http.StatusFound)
|
||||
return
|
||||
}
|
||||
var authCheckInterval time.Duration
|
||||
if checkInterval, found := oauthCfg.(map[string]interface{})["authorization_check_interval"]; !found {
|
||||
logrus.Errorf("Missing authorization check interval in share config. Defaulting to 3 hours")
|
||||
authCheckInterval = 3 * time.Hour
|
||||
} else {
|
||||
i, err := time.ParseDuration(checkInterval.(string))
|
||||
if err != nil {
|
||||
logrus.Errorf("unable to parse authorization check interval in share config (%v). Defaulting to 3 hours", checkInterval)
|
||||
authCheckInterval = 3 * time.Hour
|
||||
} else {
|
||||
authCheckInterval = i
|
||||
}
|
||||
}
|
||||
if claims.AuthorizationCheckInterval != authCheckInterval {
|
||||
logrus.Error("Authorization check interval mismatch. Redoing auth flow")
|
||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:28080/%s/login?share=%s", shrToken, pcfg.HostMatch, provider.(string), shrToken), http.StatusFound)
|
||||
return
|
||||
}
|
||||
if validDomains, found := oauthCfg.(map[string]interface{})["email_domains"]; found {
|
||||
if castedDomains, ok := validDomains.([]interface{}); !ok {
|
||||
logrus.Error("Invalid format for valid email domains")
|
||||
return
|
||||
} else {
|
||||
if len(castedDomains) > 0 {
|
||||
found := false
|
||||
for _, domain := range castedDomains {
|
||||
if strings.HasSuffix(claims.Email, domain.(string)) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
logrus.Warnf("Email not a valid domain")
|
||||
unauthorizedUi.WriteUnauthorized(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
} else {
|
||||
logrus.Warnf("%v -> no provider for '%v'", r.RemoteAddr, provider)
|
||||
notFoundUi.WriteNotFound(w)
|
||||
}
|
||||
} else {
|
||||
logrus.Warnf("%v -> no oauth cfg for '%v'", r.RemoteAddr, shrToken)
|
||||
notFoundUi.WriteNotFound(w)
|
||||
}
|
||||
default:
|
||||
logrus.Infof("invalid auth scheme '%v'", scheme)
|
||||
writeUnauthorizedResponse(w, realm)
|
||||
@ -218,6 +287,53 @@ func authHandler(handler http.Handler, realm string, cfg *Config, ctx ziti.Conte
|
||||
}
|
||||
}
|
||||
|
||||
func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
|
||||
if cfg.Oauth == nil {
|
||||
logrus.Info("No oauth config for access point. Skipping spin up.")
|
||||
return nil
|
||||
}
|
||||
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
|
||||
return err
|
||||
}
|
||||
zhttp.StartServer(ctx, "0.0.0.0:28080")
|
||||
return nil
|
||||
}
|
||||
|
||||
type ZrokClaims struct {
|
||||
Email string `json:"email"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
Provider string `json:"provider"`
|
||||
AuthorizationCheckInterval time.Duration `json:"authorizationCheckInterval"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func SetZrokCookie(w http.ResponseWriter, email, accessToken, provider string, checkInterval time.Duration, key []byte) {
|
||||
tkn := jwt.NewWithClaims(jwt.SigningMethodHS256, ZrokClaims{
|
||||
Email: email,
|
||||
AccessToken: accessToken,
|
||||
Provider: provider,
|
||||
AuthorizationCheckInterval: checkInterval,
|
||||
})
|
||||
sTkn, err := tkn.SignedString(key)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("After signing cookie token: %v", err.Error()), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "zrok-access",
|
||||
Value: sTkn,
|
||||
MaxAge: 3000,
|
||||
Domain: "localzrok.io",
|
||||
Path: "/",
|
||||
Expires: time.Now().Add(checkInterval),
|
||||
//Secure: true, //When tls gets added have this be configured on if tls
|
||||
})
|
||||
}
|
||||
|
||||
func writeUnauthorizedResponse(w http.ResponseWriter, realm string) {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
|
||||
w.WriteHeader(401)
|
||||
|
Reference in New Issue
Block a user