mirror of
https://github.com/atuinsh/atuin.git
synced 2025-01-13 17:58:54 +01:00
feat: Add TLS to atuin-server (#1457)
* Add TLS to atuin-server atuin as a project already includes most of the dependencies necessary for server-side TLS. This allows `atuin server start` to use a TLS certificate when self-hosting in order to avoid the complication of wrapping it in a TLS-aware proxy server. Configuration is handled similar to the metrics server with its own struct and currently accepts only the private key and certificate file paths. Starting a TLS server and a TCP server are divergent because the tests need to bind to an arbitrary port to avoid collisions across tests. The API to accomplish this for a TLS server is much more verbose. * Fix clippy, fmt * Add TLS section to self-hosting
This commit is contained in:
parent
86f50e0356
commit
d52e576129
33
Cargo.lock
generated
33
Cargo.lock
generated
@ -109,6 +109,12 @@ version = "1.0.75"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
|
||||
|
||||
[[package]]
|
||||
name = "arc-swap"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
|
||||
|
||||
[[package]]
|
||||
name = "argon2"
|
||||
version = "0.5.2"
|
||||
@ -269,15 +275,20 @@ dependencies = [
|
||||
"atuin-common",
|
||||
"atuin-server-database",
|
||||
"axum",
|
||||
"axum-server",
|
||||
"base64 0.21.5",
|
||||
"config",
|
||||
"eyre",
|
||||
"fs-err",
|
||||
"http",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"rand",
|
||||
"reqwest",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@ -372,6 +383,26 @@ dependencies = [
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-server"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "447f28c85900215cc1bea282f32d4a2f22d55c5a300afdfbc661c8d6a632e063"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.69"
|
||||
@ -1445,7 +1476,9 @@ dependencies = [
|
||||
"futures-util",
|
||||
"http",
|
||||
"hyper",
|
||||
"log",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
]
|
||||
|
@ -26,11 +26,16 @@ rand = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
axum = "0.6.4"
|
||||
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
|
||||
http = "0.2"
|
||||
hyper = "0.14"
|
||||
hyper-rustls = "0.24"
|
||||
fs-err = { workspace = true }
|
||||
tower = "0.4"
|
||||
tower-http = { version = "0.4", features = ["trace"] }
|
||||
reqwest = { workspace = true }
|
||||
rustls = "0.21"
|
||||
rustls-pemfile = "1.0"
|
||||
argon2 = "0.5.0"
|
||||
semver = { workspace = true }
|
||||
metrics-exporter-prometheus = "0.12.1"
|
||||
|
@ -27,3 +27,8 @@
|
||||
# enable = false
|
||||
# host = 127.0.0.1
|
||||
# port = 9001
|
||||
|
||||
# [tls]
|
||||
# enable = false
|
||||
# cert_path = ""
|
||||
# pkey_path = ""
|
||||
|
@ -1,10 +1,13 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::{future::Future, net::TcpListener};
|
||||
|
||||
use atuin_server_database::Database;
|
||||
use axum::Router;
|
||||
use axum::Server;
|
||||
use axum_server::Handle;
|
||||
use eyre::{Context, Result};
|
||||
|
||||
mod handlers;
|
||||
@ -12,6 +15,7 @@ mod metrics;
|
||||
mod router;
|
||||
mod utils;
|
||||
|
||||
use rustls::ServerConfig;
|
||||
pub use settings::example_config;
|
||||
pub use settings::Settings;
|
||||
|
||||
@ -44,27 +48,26 @@ async fn shutdown_signal() {
|
||||
|
||||
pub async fn launch<Db: Database>(
|
||||
settings: Settings<Db::Settings>,
|
||||
host: &str,
|
||||
port: u16,
|
||||
addr: SocketAddr,
|
||||
) -> Result<()> {
|
||||
launch_with_listener::<Db>(
|
||||
settings,
|
||||
TcpListener::bind((host, port)).context("could not connect to socket")?,
|
||||
shutdown_signal(),
|
||||
)
|
||||
.await
|
||||
if settings.tls.enable {
|
||||
launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
|
||||
} else {
|
||||
launch_with_tcp_listener::<Db>(
|
||||
settings,
|
||||
TcpListener::bind(addr).context("could not connect to socket")?,
|
||||
shutdown_signal(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn launch_with_listener<Db: Database>(
|
||||
pub async fn launch_with_tcp_listener<Db: Database>(
|
||||
settings: Settings<Db::Settings>,
|
||||
listener: TcpListener,
|
||||
shutdown: impl Future<Output = ()>,
|
||||
) -> Result<()> {
|
||||
let db = Db::new(&settings.db_settings)
|
||||
.await
|
||||
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
|
||||
|
||||
let r = router::router(db, settings);
|
||||
let r = make_router::<Db>(settings).await?;
|
||||
|
||||
Server::from_tcp(listener)
|
||||
.context("could not launch server")?
|
||||
@ -75,6 +78,40 @@ pub async fn launch_with_listener<Db: Database>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn launch_with_tls<Db: Database>(
|
||||
settings: Settings<Db::Settings>,
|
||||
addr: SocketAddr,
|
||||
shutdown: impl Future<Output = ()>,
|
||||
) -> Result<()> {
|
||||
let certificates = settings.tls.certificates()?;
|
||||
let pkey = settings.tls.private_key()?;
|
||||
|
||||
let server_config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certificates, pkey)?;
|
||||
|
||||
let server_config = Arc::new(server_config);
|
||||
let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config);
|
||||
|
||||
let r = make_router::<Db>(settings).await?;
|
||||
|
||||
let handle = Handle::new();
|
||||
|
||||
let server = axum_server::bind_rustls(addr, rustls_config)
|
||||
.handle(handle.clone())
|
||||
.serve(r.into_make_service());
|
||||
|
||||
tokio::select! {
|
||||
_ = server => {}
|
||||
_ = shutdown => {
|
||||
handle.graceful_shutdown(None);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The separate listener means it's much easier to ensure metrics are not accidentally exposed to
|
||||
// the public.
|
||||
pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
|
||||
@ -95,3 +132,13 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn make_router<Db: Database>(
|
||||
settings: Settings<<Db as Database>::Settings>,
|
||||
) -> Result<Router, eyre::Error> {
|
||||
let db = Db::new(&settings.db_settings)
|
||||
.await
|
||||
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
|
||||
let r = router::router(db, settings);
|
||||
Ok(r)
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::{io::prelude::*, path::PathBuf};
|
||||
|
||||
use config::{Config, Environment, File as ConfigFile, FileFormat};
|
||||
use eyre::{eyre, Result};
|
||||
use eyre::{bail, eyre, Context, Result};
|
||||
use fs_err::{create_dir_all, File};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
|
||||
@ -36,6 +36,7 @@ pub struct Settings<DbSettings> {
|
||||
pub register_webhook_url: Option<String>,
|
||||
pub register_webhook_username: String,
|
||||
pub metrics: Metrics,
|
||||
pub tls: Tls,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub db_settings: DbSettings,
|
||||
@ -67,6 +68,9 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
|
||||
.set_default("metrics.enable", false)?
|
||||
.set_default("metrics.host", "127.0.0.1")?
|
||||
.set_default("metrics.port", 9001)?
|
||||
.set_default("tls.enable", false)?
|
||||
.set_default("tls.cert_path", "")?
|
||||
.set_default("tls.key_path", "")?
|
||||
.add_source(
|
||||
Environment::with_prefix("atuin")
|
||||
.prefix_separator("_")
|
||||
@ -97,3 +101,51 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
|
||||
pub fn example_config() -> &'static str {
|
||||
EXAMPLE_CONFIG
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Tls {
|
||||
pub enable: bool,
|
||||
pub cert_path: PathBuf,
|
||||
pub pkey_path: PathBuf,
|
||||
}
|
||||
|
||||
impl Tls {
|
||||
pub fn certificates(&self) -> Result<Vec<rustls::Certificate>> {
|
||||
let cert_file = std::fs::File::open(&self.cert_path)
|
||||
.with_context(|| format!("tls.cert_path {:?} is missing", self.cert_path))?;
|
||||
let mut reader = std::io::BufReader::new(cert_file);
|
||||
let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
|
||||
.with_context(|| format!("tls.cert_path {:?} is invalid", self.cert_path))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect();
|
||||
|
||||
if certs.is_empty() {
|
||||
bail!(
|
||||
"tls.cert_path {:?} must have at least one certificate",
|
||||
self.cert_path
|
||||
);
|
||||
}
|
||||
|
||||
Ok(certs)
|
||||
}
|
||||
|
||||
pub fn private_key(&self) -> Result<rustls::PrivateKey> {
|
||||
let pkey_file = std::fs::File::open(&self.pkey_path)
|
||||
.with_context(|| format!("tls.pkey_path {:?} is missing", self.pkey_path))?;
|
||||
let mut reader = std::io::BufReader::new(pkey_file);
|
||||
let keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
|
||||
.with_context(|| format!("tls.pkey_path {:?} is not PKCS8-encoded", self.pkey_path))?;
|
||||
|
||||
if keys.is_empty() {
|
||||
bail!(
|
||||
"tls.pkey_path {:?} must have at least one private key",
|
||||
self.pkey_path
|
||||
);
|
||||
}
|
||||
|
||||
let key = rustls::PrivateKey(keys[0].clone());
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use atuin_server_postgres::Postgres;
|
||||
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
||||
|
||||
@ -39,6 +41,7 @@ impl Cmd {
|
||||
let settings = Settings::new().wrap_err("could not load server settings")?;
|
||||
let host = host.as_ref().unwrap_or(&settings.host).clone();
|
||||
let port = port.unwrap_or(settings.port);
|
||||
let addr = SocketAddr::new(host.parse()?, port);
|
||||
|
||||
if settings.metrics.enable {
|
||||
tokio::spawn(launch_metrics_server(
|
||||
@ -47,7 +50,7 @@ impl Cmd {
|
||||
));
|
||||
}
|
||||
|
||||
launch::<Postgres>(settings, &host, port).await
|
||||
launch::<Postgres>(settings, addr).await
|
||||
}
|
||||
Self::DefaultConfig => {
|
||||
println!("{}", example_config());
|
||||
|
@ -2,7 +2,7 @@ use std::{env, net::TcpListener, time::Duration};
|
||||
|
||||
use atuin_client::api_client;
|
||||
use atuin_common::{api::AddHistoryRequest, utils::uuid_v7};
|
||||
use atuin_server::{launch_with_listener, Settings as ServerSettings};
|
||||
use atuin_server::{launch_with_tcp_listener, Settings as ServerSettings};
|
||||
use atuin_server_postgres::{Postgres, PostgresSettings};
|
||||
use futures_util::TryFutureExt;
|
||||
use time::OffsetDateTime;
|
||||
@ -38,6 +38,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
|
||||
register_webhook_username: String::new(),
|
||||
db_settings: PostgresSettings { db_uri },
|
||||
metrics: atuin_server::settings::Metrics::default(),
|
||||
tls: atuin_server::settings::Tls::default(),
|
||||
};
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
|
||||
@ -46,7 +47,7 @@ async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()
|
||||
let server = tokio::spawn(async move {
|
||||
let _tracing_guard = dispatcher::set_default(&dispatch);
|
||||
|
||||
if let Err(e) = launch_with_listener::<Postgres>(
|
||||
if let Err(e) = launch_with_tcp_listener::<Postgres>(
|
||||
server_settings,
|
||||
listener,
|
||||
shutdown_rx.unwrap_or_else(|_| ()),
|
||||
|
@ -39,3 +39,14 @@ ATUIN_DB_URI="postgres://user:password@hostname/database"
|
||||
| `db_uri` | A valid PostgreSQL URI, for saving history (default: false) |
|
||||
| `path` | A path to prepend to all routes of the server (default: false) |
|
||||
|
||||
### TLS
|
||||
|
||||
The server supports TLS through the `[tls]` section:
|
||||
|
||||
```toml
|
||||
[tls]
|
||||
enabled = true
|
||||
cert_path = "/path/to/letsencrypt/live/fully.qualified.domain/fullchain.pem"
|
||||
pkey_path = "/path/to/letsencrypt/live/fully.qualified.domain/privkey.pem"
|
||||
```
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user