mirror of
https://github.com/openziti/zrok.git
synced 2025-06-20 09:48:07 +02:00
lint and cleanup; oauthLoginRequired (#404)
This commit is contained in:
parent
b63b1fc145
commit
adbe4e78c0
@ -1,8 +1,13 @@
|
|||||||
package publicProxy
|
package publicProxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"github.com/michaelquigley/cf"
|
"github.com/michaelquigley/cf"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
zhttp "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@ -16,10 +21,10 @@ type OauthConfig struct {
|
|||||||
Port int
|
Port int
|
||||||
RedirectUrl string
|
RedirectUrl string
|
||||||
HashKeyRaw string
|
HashKeyRaw string
|
||||||
Providers []*OauthProviderSecrets
|
Providers []*OauthProviderConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oc *OauthConfig) GetProvider(name string) *OauthProviderSecrets {
|
func (oc *OauthConfig) GetProvider(name string) *OauthProviderConfig {
|
||||||
for _, provider := range oc.Providers {
|
for _, provider := range oc.Providers {
|
||||||
if provider.Name == name {
|
if provider.Name == name {
|
||||||
return provider
|
return provider
|
||||||
@ -28,7 +33,7 @@ func (oc *OauthConfig) GetProvider(name string) *OauthProviderSecrets {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type OauthProviderSecrets struct {
|
type OauthProviderConfig struct {
|
||||||
Name string
|
Name string
|
||||||
ClientId string
|
ClientId string
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
@ -47,3 +52,18 @@ func (c *Config) Load(path string) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
|
||||||
|
if cfg.Oauth == nil {
|
||||||
|
logrus.Info("no oauth configuration; skipping oauth handler startup")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
zhttp.StartServer(ctx, fmt.Sprintf("%s:%d", strings.Split(cfg.Address, ":")[0], cfg.Oauth.Port))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/openziti/zrok/util"
|
"github.com/openziti/zrok/util"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
zhttp "github.com/zitadel/oidc/v2/pkg/http"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
@ -192,6 +191,8 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
|
|||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
case string(sdk.Oauth):
|
case string(sdk.Oauth):
|
||||||
|
logrus.Debugf("auth scheme oauth '%v", shrToken)
|
||||||
|
|
||||||
if oauthCfg, found := cfg["oauth"]; found {
|
if oauthCfg, found := cfg["oauth"]; found {
|
||||||
if provider, found := oauthCfg.(map[string]interface{})["provider"]; found {
|
if provider, found := oauthCfg.(map[string]interface{})["provider"]; found {
|
||||||
var authCheckInterval time.Duration
|
var authCheckInterval time.Duration
|
||||||
@ -212,35 +213,35 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
|
|||||||
|
|
||||||
cookie, err := r.Cookie("zrok-access")
|
cookie, err := r.Cookie("zrok-access")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorf("Unable to get access cookie: %v", err)
|
logrus.Errorf("unable to get 'zrok-access' cookie: %v", err)
|
||||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
|
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tkn, err := jwt.ParseWithClaims(cookie.Value, &ZrokClaims{}, func(t *jwt.Token) (interface{}, error) {
|
tkn, err := jwt.ParseWithClaims(cookie.Value, &ZrokClaims{}, func(t *jwt.Token) (interface{}, error) {
|
||||||
if pcfg.Oauth == nil {
|
if pcfg.Oauth == nil {
|
||||||
return nil, fmt.Errorf("missing oauth configuration for access point. Unable to parse jwt")
|
return nil, fmt.Errorf("missing oauth configuration for access point; unable to parse jwt")
|
||||||
}
|
}
|
||||||
return []byte(pcfg.Oauth.HashKeyRaw), nil
|
return []byte(pcfg.Oauth.HashKeyRaw), nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorf("Unable to parse JWT: %v", err)
|
logrus.Errorf("unable to parse jwt: %v", err)
|
||||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
|
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
claims := tkn.Claims.(*ZrokClaims)
|
claims := tkn.Claims.(*ZrokClaims)
|
||||||
if claims.Provider != provider {
|
if claims.Provider != provider {
|
||||||
logrus.Error("Provider mismatch. Redoing auth flow")
|
logrus.Error("provider mismatch; restarting auth flow")
|
||||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
|
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if claims.AuthorizationCheckInterval != authCheckInterval {
|
if claims.AuthorizationCheckInterval != authCheckInterval {
|
||||||
logrus.Error("Authorization check interval mismatch. Redoing auth flow")
|
logrus.Error("authorization check interval mismatch; restarting auth flow")
|
||||||
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider.(string), url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
|
oauthLoginRequired(w, r, shrToken, pcfg, provider.(string), target, authCheckInterval)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if validDomains, found := oauthCfg.(map[string]interface{})["email_domains"]; found {
|
if validDomains, found := oauthCfg.(map[string]interface{})["email_domains"]; found {
|
||||||
if castedDomains, ok := validDomains.([]interface{}); !ok {
|
if castedDomains, ok := validDomains.([]interface{}); !ok {
|
||||||
logrus.Error("Invalid format for valid email domains")
|
logrus.Error("invalid email domain format")
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if len(castedDomains) > 0 {
|
if len(castedDomains) > 0 {
|
||||||
@ -252,7 +253,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
logrus.Warnf("Email not a valid domain")
|
logrus.Warnf("invalid email domain")
|
||||||
unauthorizedUi.WriteUnauthorized(w)
|
unauthorizedUi.WriteUnauthorized(w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -261,6 +262,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
|
|||||||
}
|
}
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
logrus.Warnf("%v -> no provider for '%v'", r.RemoteAddr, provider)
|
logrus.Warnf("%v -> no provider for '%v'", r.RemoteAddr, provider)
|
||||||
notFoundUi.WriteNotFound(w)
|
notFoundUi.WriteNotFound(w)
|
||||||
@ -293,21 +295,6 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureOauthHandlers(ctx context.Context, cfg *Config, tls bool) error {
|
|
||||||
if cfg.Oauth == nil {
|
|
||||||
logrus.Info("No oauth config for access point. Skipping spin up.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := configureGoogleOauth(cfg.Oauth, tls); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := configureGithubOauth(cfg.Oauth, tls); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
zhttp.StartServer(ctx, fmt.Sprintf("%s:%d", strings.Split(cfg.Address, ":")[0], cfg.Oauth.Port))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ZrokClaims struct {
|
type ZrokClaims struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
AccessToken string `json:"accessToken"`
|
AccessToken string `json:"accessToken"`
|
||||||
@ -325,7 +312,7 @@ func SetZrokCookie(w http.ResponseWriter, domain, email, accessToken, provider s
|
|||||||
})
|
})
|
||||||
sTkn, err := tkn.SignedString(key)
|
sTkn, err := tkn.SignedString(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("After signing cookie token: %v", err.Error()), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("after signing cookie token: %v", err.Error()), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -346,6 +333,10 @@ func basicAuthRequired(w http.ResponseWriter, realm string) {
|
|||||||
w.Write([]byte("No Authorization\n"))
|
w.Write([]byte("No Authorization\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oauthLoginRequired(w http.ResponseWriter, r *http.Request, shrToken string, pcfg *Config, provider, target string, authCheckInterval time.Duration) {
|
||||||
|
http.Redirect(w, r, fmt.Sprintf("http://%s.%s:%d/%s/login?targethost=%s&checkInterval=%s", shrToken, pcfg.HostMatch, pcfg.Oauth.Port, provider, url.QueryEscape(target), authCheckInterval.String()), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
func resolveService(hostMatch string, host string) string {
|
func resolveService(hostMatch string, host string) string {
|
||||||
logrus.Debugf("host = '%v'", host)
|
logrus.Debugf("host = '%v'", host)
|
||||||
if hostMatch == "" || strings.Contains(host, hostMatch) {
|
if hostMatch == "" || strings.Contains(host, hostMatch) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user