diff --git a/endpoints/publicProxy/config.go b/endpoints/publicProxy/config.go index 233c1a4e..c39708a7 100644 --- a/endpoints/publicProxy/config.go +++ b/endpoints/publicProxy/config.go @@ -23,6 +23,7 @@ type Config struct { type InterstitialConfig struct { Enabled bool + HtmlPath string UserAgentPrefixes []string } diff --git a/endpoints/publicProxy/http.go b/endpoints/publicProxy/http.go index 69de6330..2cd52eba 100644 --- a/endpoints/publicProxy/http.go +++ b/endpoints/publicProxy/http.go @@ -180,7 +180,7 @@ func shareHandler(handler http.Handler, pcfg *Config, key []byte, ctx ziti.Conte _, zrokOkErr := r.Cookie("zrok_interstitial") if skip == "" && zrokOkErr != nil { logrus.Debugf("forcing interstitial for '%v'", r.URL) - interstitialUi.WriteInterstitialAnnounce(w) + interstitialUi.WriteInterstitialAnnounce(w, pcfg.Interstitial.HtmlPath) return } } diff --git a/endpoints/publicProxy/interstitialUi/handler.go b/endpoints/publicProxy/interstitialUi/handler.go index 98bd5dbc..2ae08753 100644 --- a/endpoints/publicProxy/interstitialUi/handler.go +++ b/endpoints/publicProxy/interstitialUi/handler.go @@ -3,19 +3,35 @@ package interstitialUi import ( "github.com/sirupsen/logrus" "net/http" + "os" ) -func WriteInterstitialAnnounce(w http.ResponseWriter) { - if data, err := FS.ReadFile("index.html"); err == nil { - w.WriteHeader(http.StatusOK) - n, err := w.Write(data) - if n != len(data) { - logrus.Errorf("short write") - return - } - if err != nil { - logrus.Error(err) - return +var externalFile []byte + +func WriteInterstitialAnnounce(w http.ResponseWriter, htmlPath string) { + if htmlPath != "" && externalFile == nil { + if data, err := os.ReadFile(htmlPath); err == nil { + externalFile = data + } else { + logrus.Errorf("error reading external interstitial file '%v': %v", htmlPath, err) } } + var htmlData = externalFile + if htmlData == nil { + if data, err := FS.ReadFile("index.html"); err == nil { + htmlData = data + } else { + logrus.Errorf("error reading embedded interstitial html 'index.html': %v", err) + } + } + w.WriteHeader(http.StatusOK) + n, err := w.Write(htmlData) + if n != len(htmlData) { + logrus.Errorf("short write") + return + } + if err != nil { + logrus.Error(err) + return + } }