feat: Add TLS to atuin-server (#1457)

* Add TLS to atuin-server

atuin as a project already includes most of the dependencies necessary
for server-side TLS.  This allows `atuin server start` to use a TLS
certificate when self-hosting in order to avoid the complication of
wrapping it in a TLS-aware proxy server.

Configuration is handled similar to the metrics server with its own
struct and currently accepts only the private key and certificate file
paths.

Starting a TLS server and a TCP server are divergent because the tests
need to bind to an arbitrary port to avoid collisions across tests.  The
API to accomplish this for a TLS server is much more verbose.

* Fix clippy, fmt

* Add TLS section to self-hosting
This commit is contained in:
Eric Hodel 2023-12-27 06:15:48 -08:00 committed by GitHub
parent 86f50e0356
commit d52e576129
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 175 additions and 18 deletions

33
Cargo.lock generated
View File

@ -109,6 +109,12 @@ version = "1.0.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
[[package]]
name = "arc-swap"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
[[package]]
name = "argon2"
version = "0.5.2"
@ -269,15 +275,20 @@ dependencies = [
"atuin-common",
"atuin-server-database",
"axum",
"axum-server",
"base64 0.21.5",
"config",
"eyre",
"fs-err",
"http",
"hyper",
"hyper-rustls",
"metrics",
"metrics-exporter-prometheus",
"rand",
"reqwest",
"rustls",
"rustls-pemfile",
"semver",
"serde",
"serde_json",
@ -372,6 +383,26 @@ dependencies = [
"tower-service",
]
[[package]]
name = "axum-server"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "447f28c85900215cc1bea282f32d4a2f22d55c5a300afdfbc661c8d6a632e063"
dependencies = [
"arc-swap",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.69"
@ -1445,7 +1476,9 @@ dependencies = [
"futures-util",
"http",
"hyper",
"log",
"rustls",
"rustls-native-certs",
"tokio",
"tokio-rustls",
]

View File

@ -26,11 +26,16 @@ rand = { workspace = true }
tokio = { workspace = true }
async-trait = { workspace = true }
axum = "0.6.4"
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
http = "0.2"
hyper = "0.14"
hyper-rustls = "0.24"
fs-err = { workspace = true }
tower = "0.4"
tower-http = { version = "0.4", features = ["trace"] }
reqwest = { workspace = true }
rustls = "0.21"
rustls-pemfile = "1.0"
argon2 = "0.5.0"
semver = { workspace = true }
metrics-exporter-prometheus = "0.12.1"

View File

@ -27,3 +27,8 @@
# enable = false
# host = 127.0.0.1
# port = 9001
# [tls]
# enable = false
# cert_path = ""
# pkey_path = ""

View File

@ -1,10 +1,13 @@
#![forbid(unsafe_code)]
use std::net::SocketAddr;
use std::sync::Arc;
use std::{future::Future, net::TcpListener};
use atuin_server_database::Database;
use axum::Router;
use axum::Server;
use axum_server::Handle;
use eyre::{Context, Result};
mod handlers;
@ -12,6 +15,7 @@ mod metrics;
mod router;
mod utils;
use rustls::ServerConfig;
pub use settings::example_config;
pub use settings::Settings;
@ -44,27 +48,26 @@ async fn shutdown_signal() {
pub async fn launch<Db: Database>(
settings: Settings<Db::Settings>,
host: &str,
port: u16,
addr: SocketAddr,
) -> Result<()> {
launch_with_listener::<Db>(
settings,
TcpListener::bind((host, port)).context("could not connect to socket")?,
shutdown_signal(),
)
.await
if settings.tls.enable {
launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
} else {
launch_with_tcp_listener::<Db>(
settings,
TcpListener::bind(addr).context("could not connect to socket")?,
shutdown_signal(),
)
.await
}
}
pub async fn launch_with_listener<Db: Database>(
pub async fn launch_with_tcp_listener<Db: Database>(
settings: Settings<Db::Settings>,
listener: TcpListener,
shutdown: impl Future<Output = ()>,
) -> Result<()> {
let db = Db::new(&settings.db_settings)
.await
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
let r = router::router(db, settings);
let r = make_router::<Db>(settings).await?;
Server::from_tcp(listener)
.context("could not launch server")?
@ -75,6 +78,40 @@ pub async fn launch_with_listener<Db: Database>(
Ok(())
}
async fn launch_with_tls<Db: Database>(
settings: Settings<Db::Settings>,
addr: SocketAddr,
shutdown: impl Future<Output = ()>,
) -> Result<()> {
let certificates = settings.tls.certificates()?;
let pkey = settings.tls.private_key()?;
let server_config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certificates, pkey)?;
let server_config = Arc::new(server_config);
let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config);
let r = make_router::<Db>(settings).await?;
let handle = Handle::new();
let server = axum_server::bind_rustls(addr, rustls_config)
.handle(handle.clone())
.serve(r.into_make_service());
tokio::select! {
_ = server => {}
_ = shutdown => {
handle.graceful_shutdown(None);
}
}
Ok(())
}
// The separate listener means it's much easier to ensure metrics are not accidentally exposed to
// the public.
pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
@ -95,3 +132,13 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
Ok(())
}
async fn make_router<Db: Database>(
settings: Settings<<Db as Database>::Settings>,
) -> Result<Router, eyre::Error> {
let db = Db::new(&settings.db_settings)
.await
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
let r = router::router(db, settings);
Ok(r)
}

View File

@ -1,7 +1,7 @@
use std::{io::prelude::*, path::PathBuf};
use config::{Config, Environment, File as ConfigFile, FileFormat};
use eyre::{eyre, Result};
use eyre::{bail, eyre, Context, Result};
use fs_err::{create_dir_all, File};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -36,6 +36,7 @@ pub struct Settings<DbSettings> {
pub register_webhook_url: Option<String>,
pub register_webhook_username: String,
pub metrics: Metrics,
pub tls: Tls,
#[serde(flatten)]
pub db_settings: DbSettings,
@ -67,6 +68,9 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
.set_default("metrics.enable", false)?
.set_default("metrics.host", "127.0.0.1")?
.set_default("metrics.port", 9001)?
.set_default("tls.enable", false)?
.set_default("tls.cert_path", "")?
.set_default("tls.key_path", "")?
.add_source(
Environment::with_prefix("atuin")
.prefix_separator("_")
@ -97,3 +101,51 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
pub fn example_config() -> &'static str {
EXAMPLE_CONFIG
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Tls {
pub enable: bool,
pub cert_path: PathBuf,
pub pkey_path: PathBuf,
}
impl Tls {
pub fn certificates(&self) -> Result<Vec<rustls::Certificate>> {
let cert_file = std::fs::File::open(&self.cert_path)
.with_context(|| format!("tls.cert_path {:?} is missing", self.cert_path))?;
let mut reader = std::io::BufReader::new(cert_file);
let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
.with_context(|| format!("tls.cert_path {:?} is invalid", self.cert_path))?
.into_iter()
.map(rustls::Certificate)
.collect();
if certs.is_empty() {
bail!(
"tls.cert_path {:?} must have at least one certificate",
self.cert_path
);
}
Ok(certs)
}
pub fn private_key(&self) -> Result<rustls::PrivateKey> {
let pkey_file = std::fs::File::open(&self.pkey_path)
.with_context(|| format!("tls.pkey_path {:?} is missing", self.pkey_path))?;
let mut reader = std::io::BufReader::new(pkey_file);
let keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
.with_context(|| format!("tls.pkey_path {:?} is not PKCS8-encoded", self.pkey_path))?;
if keys.is_empty() {
bail!(
"tls.pkey_path {:?} must have at least one private key",
self.pkey_path
);
}
let key = rustls::PrivateKey(keys[0].clone());
Ok(key)
}
}

View File

@ -1,3 +1,5 @@
use std::net::SocketAddr;
use atuin_server_postgres::Postgres;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
@ -39,6 +41,7 @@ impl Cmd {
let settings = Settings::new().wrap_err("could not load server settings")?;
let host = host.as_ref().unwrap_or(&settings.host).clone();
let port = port.unwrap_or(settings.port);
let addr = SocketAddr::new(host.parse()?, port);
if settings.metrics.enable {
tokio::spawn(launch_metrics_server(
@ -47,7 +50,7 @@ impl Cmd {
));
}
launch::<Postgres>(settings, &host, port).await
launch::<Postgres>(settings, addr).await
}
Self::DefaultConfig => {
println!("{}", example_config());

View File

@ -2,7 +2,7 @@ use std::{env, net::TcpListener, time::Duration};
use atuin_client::api_client;
use atuin_common::{api::AddHistoryRequest, utils::uuid_v7};
use atuin_server::{launch_with_listener, Settings as ServerSettings};
use atuin_server::{launch_with_tcp_listener, Settings as ServerSettings};
use atuin_server_postgres::{Postgres, PostgresSettings};
use futures_util::TryFutureExt;
use time::OffsetDateTime;
@ -38,6 +38,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
register_webhook_username: String::new(),
db_settings: PostgresSettings { db_uri },
metrics: atuin_server::settings::Metrics::default(),
tls: atuin_server::settings::Tls::default(),
};
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
@ -46,7 +47,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
let server = tokio::spawn(async move {
let _tracing_guard = dispatcher::set_default(&dispatch);
if let Err(e) = launch_with_listener::<Postgres>(
if let Err(e) = launch_with_tcp_listener::<Postgres>(
server_settings,
listener,
shutdown_rx.unwrap_or_else(|_| ()),

View File

@ -39,3 +39,14 @@ ATUIN_DB_URI="postgres://user:password@hostname/database"
| `db_uri` | A valid PostgreSQL URI, for saving history (default: false) |
| `path` | A path to prepend to all routes of the server (default: false) |
### TLS
The server supports TLS through the `[tls]` section:
```toml
[tls]
enabled = true
cert_path = "/path/to/letsencrypt/live/fully.qualified.domain/fullchain.pem"
pkey_path = "/path/to/letsencrypt/live/fully.qualified.domain/privkey.pem"
```