use MD5 hash to get reliable 16-byte key (#404)

This commit is contained in:
Michael Quigley 2023-09-28 14:39:31 -04:00
parent 27fcf98fbd
commit 483c599c93
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
4 changed files with 42 additions and 8 deletions

View File

@ -21,7 +21,7 @@ type OauthConfig struct {
Host string Host string
Port int Port int
RedirectUrl string RedirectUrl string
HashKeyRaw string HashKeyRaw string `cf:"+secret"`
Providers []*OauthProviderConfig Providers []*OauthProviderConfig
} }

View File

@ -1,7 +1,9 @@
package publicProxy package publicProxy
import ( import (
"crypto/md5"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -27,7 +29,7 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
providerCfg := cfg.GetProvider("github") providerCfg := cfg.GetProvider("github")
if providerCfg == nil { if providerCfg == nil {
logrus.Info("unable to find provider config for github. Skipping.") logrus.Info("unable to find provider config for github; skipping")
return nil return nil
} }
clientID := providerCfg.ClientId clientID := providerCfg.ClientId
@ -42,7 +44,15 @@ func configureGithubOauth(cfg *OauthConfig, tls bool) error {
Endpoint: githubOAuth.Endpoint, Endpoint: githubOAuth.Endpoint,
} }
key := []byte(cfg.HashKeyRaw) hash := md5.New()
n, err := hash.Write([]byte(cfg.HashKeyRaw))
if err != nil {
return err
}
if n != len(cfg.HashKeyRaw) {
return errors.New("short hash")
}
key := hash.Sum(nil)
u, err := url.Parse(redirectUrl) u, err := url.Parse(redirectUrl)
if err != nil { if err != nil {

View File

@ -1,7 +1,9 @@
package publicProxy package publicProxy
import ( import (
"crypto/md5"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -43,7 +45,15 @@ func configureGoogleOauth(cfg *OauthConfig, tls bool) error {
Endpoint: googleOauth.Endpoint, Endpoint: googleOauth.Endpoint,
} }
key := []byte(cfg.HashKeyRaw) hash := md5.New()
n, err := hash.Write([]byte(cfg.HashKeyRaw))
if err != nil {
return err
}
if n != len(cfg.HashKeyRaw) {
return errors.New("short hash")
}
key := hash.Sum(nil)
u, err := url.Parse(redirectUrl) u, err := url.Parse(redirectUrl)
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package publicProxy
import ( import (
"context" "context"
"crypto/md5"
"fmt" "fmt"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/openziti/sdk-golang/ziti" "github.com/openziti/sdk-golang/ziti"
@ -29,6 +30,19 @@ type HttpFrontend struct {
} }
func NewHTTP(cfg *Config) (*HttpFrontend, error) { func NewHTTP(cfg *Config) (*HttpFrontend, error) {
var key []byte
if cfg.Oauth != nil {
hash := md5.New()
n, err := hash.Write([]byte(cfg.Oauth.HashKeyRaw))
if err != nil {
return nil, err
}
if n != len(cfg.Oauth.HashKeyRaw) {
return nil, errors.New("short hash")
}
key = hash.Sum(nil)
}
root, err := environment.LoadRoot() root, err := environment.LoadRoot()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error loading environment root") return nil, errors.Wrap(err, "error loading environment root")
@ -58,7 +72,7 @@ func NewHTTP(cfg *Config) (*HttpFrontend, error) {
if err := configureOauthHandlers(context.Background(), cfg, false); err != nil { if err := configureOauthHandlers(context.Background(), cfg, false); err != nil {
return nil, err return nil, err
} }
handler := authHandler(util.NewProxyHandler(proxy), cfg, zCtx) handler := authHandler(util.NewProxyHandler(proxy), cfg, key, zCtx)
return &HttpFrontend{ return &HttpFrontend{
cfg: cfg, cfg: cfg,
zCtx: zCtx, zCtx: zCtx,
@ -133,7 +147,7 @@ func hostTargetReverseProxy(cfg *Config, ctx ziti.Context) *httputil.ReverseProx
return &httputil.ReverseProxy{Director: director} return &httputil.ReverseProxy{Director: director}
} }
func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.HandlerFunc { func authHandler(handler http.Handler, pcfg *Config, key []byte, ctx ziti.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
shrToken := resolveService(pcfg.HostMatch, r.Host) shrToken := resolveService(pcfg.HostMatch, r.Host)
if shrToken != "" { if shrToken != "" {
@ -191,7 +205,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
case string(sdk.Oauth): case string(sdk.Oauth):
logrus.Debugf("auth scheme oauth '%v", shrToken) logrus.Debugf("auth scheme oauth '%v'", shrToken)
if oauthCfg, found := cfg["oauth"]; found { if oauthCfg, found := cfg["oauth"]; found {
if provider, found := oauthCfg.(map[string]interface{})["provider"]; found { if provider, found := oauthCfg.(map[string]interface{})["provider"]; found {
@ -221,7 +235,7 @@ func authHandler(handler http.Handler, pcfg *Config, ctx ziti.Context) http.Hand
if pcfg.Oauth == nil { if pcfg.Oauth == nil {
return nil, fmt.Errorf("missing oauth configuration for access point; unable to parse jwt") return nil, fmt.Errorf("missing oauth configuration for access point; unable to parse jwt")
} }
return []byte(pcfg.Oauth.HashKeyRaw), nil return key, nil
}) })
if err != nil { if err != nil {
logrus.Errorf("unable to parse jwt: %v", err) logrus.Errorf("unable to parse jwt: %v", err)