From fd17dcd204ce7e18968f27c7edc00a9e3e31bf49 Mon Sep 17 00:00:00 2001 From: TwiN Date: Thu, 20 Jul 2023 19:02:34 -0400 Subject: [PATCH] fix(tls): Pass certificate and private key files to listener method (#531) Fixes #530 --- README.md | 6 +++--- config/web/web.go | 20 +++++++------------ config/web/web_test.go | 43 ++++++++++++++++++++++++++++++++++------ controller/controller.go | 14 +++++++++---- 4 files changed, 57 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 7335d44d..c4db0c76 100644 --- a/README.md +++ b/README.md @@ -1278,14 +1278,14 @@ Confused? Read [Securing Gatus with OIDC using Auth0](https://twin.sh/articles/5 ### TLS Encryption Gatus supports basic encryption with TLS. To enable this, certificate files in PEM format have to be provided. -The example below shows an example configuration which makes gatus respond on port 4443 to HTTPS requests. +The example below shows an example configuration which makes gatus respond on port 4443 to HTTPS requests: ```yaml web: port: 4443 tls: - certificate-file: "server.crt" - private-key-file: "server.key" + certificate-file: "certificate.crt" + private-key-file: "private.key" ``` ### Metrics diff --git a/config/web/web.go b/config/web/web.go index 2e9391fb..fd3d2079 100644 --- a/config/web/web.go +++ b/config/web/web.go @@ -34,8 +34,6 @@ type TLSConfig struct { // PrivateKeyFile is the private key file for TLS in PEM format. PrivateKeyFile string `yaml:"private-key-file,omitempty"` - - tlsConfig *tls.Config } // GetDefaultConfig returns a Config struct with the default values @@ -57,33 +55,29 @@ func (web *Config) ValidateAndSetDefaults() error { } // Try to load the TLS certificates if web.TLS != nil { - if err := web.TLS.loadConfig(); err != nil { + if err := web.TLS.isValid(); err != nil { return fmt.Errorf("invalid tls config: %w", err) } } return nil } +func (web *Config) HasTLS() bool { + return web.TLS != nil && len(web.TLS.CertificateFile) > 0 && len(web.TLS.PrivateKeyFile) > 0 +} + // SocketAddress returns the combination of the Address and the Port func (web *Config) SocketAddress() string { return fmt.Sprintf("%s:%d", web.Address, web.Port) } -func (t *TLSConfig) loadConfig() error { +func (t *TLSConfig) isValid() error { if len(t.CertificateFile) > 0 && len(t.PrivateKeyFile) > 0 { - certificate, err := tls.LoadX509KeyPair(t.CertificateFile, t.PrivateKeyFile) + _, err := tls.LoadX509KeyPair(t.CertificateFile, t.PrivateKeyFile) if err != nil { return err } - t.tlsConfig = &tls.Config{Certificates: []tls.Certificate{certificate}} return nil } return errors.New("certificate-file and private-key-file must be specified") } - -func (web *Config) TLSConfig() *tls.Config { - if web.TLS != nil { - return web.TLS.tlsConfig - } - return nil -} diff --git a/config/web/web_test.go b/config/web/web_test.go index 9a8f2a34..85818926 100644 --- a/config/web/web_test.go +++ b/config/web/web_test.go @@ -37,6 +37,27 @@ func TestConfig_ValidateAndSetDefaults(t *testing.T) { cfg: &Config{Port: 100000000}, expectedErr: true, }, + { + name: "with-good-tls-config", + cfg: &Config{Port: 443, TLS: &TLSConfig{CertificateFile: "../../testdata/cert.pem", PrivateKeyFile: "../../testdata/cert.key"}}, + expectedAddress: "0.0.0.0", + expectedPort: 443, + expectedErr: false, + }, + { + name: "with-bad-tls-config", + cfg: &Config{Port: 443, TLS: &TLSConfig{CertificateFile: "../../testdata/badcert.pem", PrivateKeyFile: "../../testdata/cert.key"}}, + expectedAddress: "0.0.0.0", + expectedPort: 443, + expectedErr: true, + }, + { + name: "with-partial-tls-config", + cfg: &Config{Port: 443, TLS: &TLSConfig{CertificateFile: "", PrivateKeyFile: "../../testdata/cert.key"}}, + expectedAddress: "0.0.0.0", + expectedPort: 443, + expectedErr: true, + }, } for _, scenario := range scenarios { t.Run(scenario.name, func(t *testing.T) { @@ -67,7 +88,7 @@ func TestConfig_SocketAddress(t *testing.T) { } } -func TestConfig_TLSConfig(t *testing.T) { +func TestConfig_isValid(t *testing.T) { scenarios := []struct { name string cfg *Config @@ -79,27 +100,37 @@ func TestConfig_TLSConfig(t *testing.T) { expectedErr: false, }, { - name: "missing-crt-file", + name: "missing-certificate-file", cfg: &Config{TLS: &TLSConfig{CertificateFile: "doesnotexist", PrivateKeyFile: "../../testdata/cert.key"}}, expectedErr: true, }, { - name: "bad-crt-file", + name: "bad-certificate-file", cfg: &Config{TLS: &TLSConfig{CertificateFile: "../../testdata/badcert.pem", PrivateKeyFile: "../../testdata/cert.key"}}, expectedErr: true, }, + { + name: "no-certificate-file", + cfg: &Config{TLS: &TLSConfig{CertificateFile: "", PrivateKeyFile: "../../testdata/cert.key"}}, + expectedErr: true, + }, { name: "missing-private-key-file", cfg: &Config{TLS: &TLSConfig{CertificateFile: "../../testdata/cert.pem", PrivateKeyFile: "doesnotexist"}}, expectedErr: true, }, + { + name: "no-private-key-file", + cfg: &Config{TLS: &TLSConfig{CertificateFile: "../../testdata/cert.pem", PrivateKeyFile: ""}}, + expectedErr: true, + }, { name: "bad-private-key-file", cfg: &Config{TLS: &TLSConfig{CertificateFile: "../../testdata/cert.pem", PrivateKeyFile: "../../testdata/badcert.key"}}, expectedErr: true, }, { - name: "bad-cert-and-private-key-file", + name: "bad-certificate-and-private-key-file", cfg: &Config{TLS: &TLSConfig{CertificateFile: "../../testdata/badcert.pem", PrivateKeyFile: "../../testdata/badcert.key"}}, expectedErr: true, }, @@ -112,8 +143,8 @@ func TestConfig_TLSConfig(t *testing.T) { return } if !scenario.expectedErr { - if scenario.cfg.TLS.tlsConfig == nil { - t.Error("TLS configuration was not correctly loaded although no error was returned") + if scenario.cfg.TLS.isValid() != nil { + t.Error("cfg.TLS.isValid() returned an error even though no error was expected") } } }) diff --git a/controller/controller.go b/controller/controller.go index cc20be89..acdd175d 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -22,16 +22,22 @@ func Handle(cfg *config.Config) { server.ReadTimeout = 15 * time.Second server.WriteTimeout = 15 * time.Second server.IdleTimeout = 15 * time.Second - server.TLSConfig = cfg.Web.TLSConfig() if os.Getenv("ROUTER_TEST") == "true" { return } log.Println("[controller][Handle] Listening on " + cfg.Web.SocketAddress()) - if server.TLSConfig != nil { - log.Println("[controller][Handle]", app.ListenTLS(cfg.Web.SocketAddress(), "", "")) + if cfg.Web.HasTLS() { + err := app.ListenTLS(cfg.Web.SocketAddress(), cfg.Web.TLS.CertificateFile, cfg.Web.TLS.PrivateKeyFile) + if err != nil { + log.Fatal("[controller][Handle]", err) + } } else { - log.Println("[controller][Handle]", app.Listen(cfg.Web.SocketAddress())) + err := app.Listen(cfg.Web.SocketAddress()) + if err != nil { + log.Fatal("[controller][Handle]", err) + } } + log.Println("[controller][Handle] Server has shut down successfully") } // Shutdown stops the server