feat: Add TLS to atuin-server (#1457)

* Add TLS to atuin-server

atuin as a project already includes most of the dependencies necessary
for server-side TLS.  This allows `atuin server start` to use a TLS
certificate when self-hosting in order to avoid the complication of
wrapping it in a TLS-aware proxy server.

Configuration is handled similar to the metrics server with its own
struct and currently accepts only the private key and certificate file
paths.

Starting a TLS server and a TCP server are divergent because the tests
need to bind to an arbitrary port to avoid collisions across tests.  The
API to accomplish this for a TLS server is much more verbose.

* Fix clippy, fmt

* Add TLS section to self-hosting
This commit is contained in:
Eric Hodel 2023-12-27 06:15:48 -08:00 committed by GitHub
parent 86f50e0356
commit d52e576129
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 175 additions and 18 deletions

33
Cargo.lock generated
View File

@ -109,6 +109,12 @@ version = "1.0.75"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
[[package]]
name = "arc-swap"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
[[package]] [[package]]
name = "argon2" name = "argon2"
version = "0.5.2" version = "0.5.2"
@ -269,15 +275,20 @@ dependencies = [
"atuin-common", "atuin-common",
"atuin-server-database", "atuin-server-database",
"axum", "axum",
"axum-server",
"base64 0.21.5", "base64 0.21.5",
"config", "config",
"eyre", "eyre",
"fs-err", "fs-err",
"http", "http",
"hyper",
"hyper-rustls",
"metrics", "metrics",
"metrics-exporter-prometheus", "metrics-exporter-prometheus",
"rand", "rand",
"reqwest", "reqwest",
"rustls",
"rustls-pemfile",
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",
@ -372,6 +383,26 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-server"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "447f28c85900215cc1bea282f32d4a2f22d55c5a300afdfbc661c8d6a632e063"
dependencies = [
"arc-swap",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.69" version = "0.3.69"
@ -1445,7 +1476,9 @@ dependencies = [
"futures-util", "futures-util",
"http", "http",
"hyper", "hyper",
"log",
"rustls", "rustls",
"rustls-native-certs",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
] ]

View File

@ -26,11 +26,16 @@ rand = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
axum = "0.6.4" axum = "0.6.4"
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
http = "0.2" http = "0.2"
hyper = "0.14"
hyper-rustls = "0.24"
fs-err = { workspace = true } fs-err = { workspace = true }
tower = "0.4" tower = "0.4"
tower-http = { version = "0.4", features = ["trace"] } tower-http = { version = "0.4", features = ["trace"] }
reqwest = { workspace = true } reqwest = { workspace = true }
rustls = "0.21"
rustls-pemfile = "1.0"
argon2 = "0.5.0" argon2 = "0.5.0"
semver = { workspace = true } semver = { workspace = true }
metrics-exporter-prometheus = "0.12.1" metrics-exporter-prometheus = "0.12.1"

View File

@ -27,3 +27,8 @@
# enable = false # enable = false
# host = 127.0.0.1 # host = 127.0.0.1
# port = 9001 # port = 9001
# [tls]
# enable = false
# cert_path = ""
# pkey_path = ""

View File

@ -1,10 +1,13 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use std::net::SocketAddr;
use std::sync::Arc;
use std::{future::Future, net::TcpListener}; use std::{future::Future, net::TcpListener};
use atuin_server_database::Database; use atuin_server_database::Database;
use axum::Router; use axum::Router;
use axum::Server; use axum::Server;
use axum_server::Handle;
use eyre::{Context, Result}; use eyre::{Context, Result};
mod handlers; mod handlers;
@ -12,6 +15,7 @@ mod metrics;
mod router; mod router;
mod utils; mod utils;
use rustls::ServerConfig;
pub use settings::example_config; pub use settings::example_config;
pub use settings::Settings; pub use settings::Settings;
@ -44,27 +48,26 @@ async fn shutdown_signal() {
pub async fn launch<Db: Database>( pub async fn launch<Db: Database>(
settings: Settings<Db::Settings>, settings: Settings<Db::Settings>,
host: &str, addr: SocketAddr,
port: u16,
) -> Result<()> { ) -> Result<()> {
launch_with_listener::<Db>( if settings.tls.enable {
launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
} else {
launch_with_tcp_listener::<Db>(
settings, settings,
TcpListener::bind((host, port)).context("could not connect to socket")?, TcpListener::bind(addr).context("could not connect to socket")?,
shutdown_signal(), shutdown_signal(),
) )
.await .await
}
} }
pub async fn launch_with_listener<Db: Database>( pub async fn launch_with_tcp_listener<Db: Database>(
settings: Settings<Db::Settings>, settings: Settings<Db::Settings>,
listener: TcpListener, listener: TcpListener,
shutdown: impl Future<Output = ()>, shutdown: impl Future<Output = ()>,
) -> Result<()> { ) -> Result<()> {
let db = Db::new(&settings.db_settings) let r = make_router::<Db>(settings).await?;
.await
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
let r = router::router(db, settings);
Server::from_tcp(listener) Server::from_tcp(listener)
.context("could not launch server")? .context("could not launch server")?
@ -75,6 +78,40 @@ pub async fn launch_with_listener<Db: Database>(
Ok(()) Ok(())
} }
async fn launch_with_tls<Db: Database>(
settings: Settings<Db::Settings>,
addr: SocketAddr,
shutdown: impl Future<Output = ()>,
) -> Result<()> {
let certificates = settings.tls.certificates()?;
let pkey = settings.tls.private_key()?;
let server_config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certificates, pkey)?;
let server_config = Arc::new(server_config);
let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config);
let r = make_router::<Db>(settings).await?;
let handle = Handle::new();
let server = axum_server::bind_rustls(addr, rustls_config)
.handle(handle.clone())
.serve(r.into_make_service());
tokio::select! {
_ = server => {}
_ = shutdown => {
handle.graceful_shutdown(None);
}
}
Ok(())
}
// The separate listener means it's much easier to ensure metrics are not accidentally exposed to // The separate listener means it's much easier to ensure metrics are not accidentally exposed to
// the public. // the public.
pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
@ -95,3 +132,13 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
Ok(()) Ok(())
} }
async fn make_router<Db: Database>(
settings: Settings<<Db as Database>::Settings>,
) -> Result<Router, eyre::Error> {
let db = Db::new(&settings.db_settings)
.await
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
let r = router::router(db, settings);
Ok(r)
}

View File

@ -1,7 +1,7 @@
use std::{io::prelude::*, path::PathBuf}; use std::{io::prelude::*, path::PathBuf};
use config::{Config, Environment, File as ConfigFile, FileFormat}; use config::{Config, Environment, File as ConfigFile, FileFormat};
use eyre::{eyre, Result}; use eyre::{bail, eyre, Context, Result};
use fs_err::{create_dir_all, File}; use fs_err::{create_dir_all, File};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -36,6 +36,7 @@ pub struct Settings<DbSettings> {
pub register_webhook_url: Option<String>, pub register_webhook_url: Option<String>,
pub register_webhook_username: String, pub register_webhook_username: String,
pub metrics: Metrics, pub metrics: Metrics,
pub tls: Tls,
#[serde(flatten)] #[serde(flatten)]
pub db_settings: DbSettings, pub db_settings: DbSettings,
@ -67,6 +68,9 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
.set_default("metrics.enable", false)? .set_default("metrics.enable", false)?
.set_default("metrics.host", "127.0.0.1")? .set_default("metrics.host", "127.0.0.1")?
.set_default("metrics.port", 9001)? .set_default("metrics.port", 9001)?
.set_default("tls.enable", false)?
.set_default("tls.cert_path", "")?
.set_default("tls.key_path", "")?
.add_source( .add_source(
Environment::with_prefix("atuin") Environment::with_prefix("atuin")
.prefix_separator("_") .prefix_separator("_")
@ -97,3 +101,51 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
pub fn example_config() -> &'static str { pub fn example_config() -> &'static str {
EXAMPLE_CONFIG EXAMPLE_CONFIG
} }
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Tls {
pub enable: bool,
pub cert_path: PathBuf,
pub pkey_path: PathBuf,
}
impl Tls {
pub fn certificates(&self) -> Result<Vec<rustls::Certificate>> {
let cert_file = std::fs::File::open(&self.cert_path)
.with_context(|| format!("tls.cert_path {:?} is missing", self.cert_path))?;
let mut reader = std::io::BufReader::new(cert_file);
let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
.with_context(|| format!("tls.cert_path {:?} is invalid", self.cert_path))?
.into_iter()
.map(rustls::Certificate)
.collect();
if certs.is_empty() {
bail!(
"tls.cert_path {:?} must have at least one certificate",
self.cert_path
);
}
Ok(certs)
}
pub fn private_key(&self) -> Result<rustls::PrivateKey> {
let pkey_file = std::fs::File::open(&self.pkey_path)
.with_context(|| format!("tls.pkey_path {:?} is missing", self.pkey_path))?;
let mut reader = std::io::BufReader::new(pkey_file);
let keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
.with_context(|| format!("tls.pkey_path {:?} is not PKCS8-encoded", self.pkey_path))?;
if keys.is_empty() {
bail!(
"tls.pkey_path {:?} must have at least one private key",
self.pkey_path
);
}
let key = rustls::PrivateKey(keys[0].clone());
Ok(key)
}
}

View File

@ -1,3 +1,5 @@
use std::net::SocketAddr;
use atuin_server_postgres::Postgres; use atuin_server_postgres::Postgres;
use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use tracing_subscriber::{fmt, prelude::*, EnvFilter};
@ -39,6 +41,7 @@ impl Cmd {
let settings = Settings::new().wrap_err("could not load server settings")?; let settings = Settings::new().wrap_err("could not load server settings")?;
let host = host.as_ref().unwrap_or(&settings.host).clone(); let host = host.as_ref().unwrap_or(&settings.host).clone();
let port = port.unwrap_or(settings.port); let port = port.unwrap_or(settings.port);
let addr = SocketAddr::new(host.parse()?, port);
if settings.metrics.enable { if settings.metrics.enable {
tokio::spawn(launch_metrics_server( tokio::spawn(launch_metrics_server(
@ -47,7 +50,7 @@ impl Cmd {
)); ));
} }
launch::<Postgres>(settings, &host, port).await launch::<Postgres>(settings, addr).await
} }
Self::DefaultConfig => { Self::DefaultConfig => {
println!("{}", example_config()); println!("{}", example_config());

View File

@ -2,7 +2,7 @@ use std::{env, net::TcpListener, time::Duration};
use atuin_client::api_client; use atuin_client::api_client;
use atuin_common::{api::AddHistoryRequest, utils::uuid_v7}; use atuin_common::{api::AddHistoryRequest, utils::uuid_v7};
use atuin_server::{launch_with_listener, Settings as ServerSettings}; use atuin_server::{launch_with_tcp_listener, Settings as ServerSettings};
use atuin_server_postgres::{Postgres, PostgresSettings}; use atuin_server_postgres::{Postgres, PostgresSettings};
use futures_util::TryFutureExt; use futures_util::TryFutureExt;
use time::OffsetDateTime; use time::OffsetDateTime;
@ -38,6 +38,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
register_webhook_username: String::new(), register_webhook_username: String::new(),
db_settings: PostgresSettings { db_uri }, db_settings: PostgresSettings { db_uri },
metrics: atuin_server::settings::Metrics::default(), metrics: atuin_server::settings::Metrics::default(),
tls: atuin_server::settings::Tls::default(),
}; };
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
@ -46,7 +47,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
let server = tokio::spawn(async move { let server = tokio::spawn(async move {
let _tracing_guard = dispatcher::set_default(&dispatch); let _tracing_guard = dispatcher::set_default(&dispatch);
if let Err(e) = launch_with_listener::<Postgres>( if let Err(e) = launch_with_tcp_listener::<Postgres>(
server_settings, server_settings,
listener, listener,
shutdown_rx.unwrap_or_else(|_| ()), shutdown_rx.unwrap_or_else(|_| ()),

View File

@ -39,3 +39,14 @@ ATUIN_DB_URI="postgres://user:password@hostname/database"
| `db_uri` | A valid PostgreSQL URI, for saving history (default: false) | | `db_uri` | A valid PostgreSQL URI, for saving history (default: false) |
| `path` | A path to prepend to all routes of the server (default: false) | | `path` | A path to prepend to all routes of the server (default: false) |
### TLS
The server supports TLS through the `[tls]` section:
```toml
[tls]
enabled = true
cert_path = "/path/to/letsencrypt/live/fully.qualified.domain/fullchain.pem"
pkey_path = "/path/to/letsencrypt/live/fully.qualified.domain/privkey.pem"
```