feat: add semver checking to client requests (#1456)

* feat: add semver checking to client requests

This enforces that the client and the server run the same major version
in order to sync successfully.

We're using the `Atuin-Version` http header to transfer this information

If the user is not on the same MAJOR, then they will see an error like
this

> Atuin version mismatch! In order to successfully sync, the client and the server must run the same *major* version
> Client: 17.1.0
> Server: 18.1.0
> Error: could not sync records due to version mismatch

This change means two things

1. We will now only increment major versions if there is a breaking
   change for sync
2. We can now add breaking changes to sync, for any version >17.1.0.
   Clients will fail in a meaningful way.

* lint, fmt, etc

* only check for client newer than server

* Add version header to client too
This commit is contained in:
Ellie Huxtable 2023-12-20 09:03:04 +00:00 committed by GitHub
parent 42ac150fe3
commit 86f50e0356
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 4 deletions

2
Cargo.lock generated
View File

@ -249,8 +249,10 @@ name = "atuin-common"
version = "17.1.0" version = "17.1.0"
dependencies = [ dependencies = [
"eyre", "eyre",
"lazy_static",
"pretty_assertions", "pretty_assertions",
"rand", "rand",
"semver",
"serde", "serde",
"sqlx", "sqlx",
"time", "time",

View File

@ -5,10 +5,9 @@ use std::time::Duration;
use eyre::{bail, Result}; use eyre::{bail, Result};
use reqwest::{ use reqwest::{
header::{HeaderMap, AUTHORIZATION, USER_AGENT}, header::{HeaderMap, AUTHORIZATION, USER_AGENT},
StatusCode, Url, Response, StatusCode, Url,
}; };
use atuin_common::record::{EncryptedData, HostId, Record, RecordId};
use atuin_common::{ use atuin_common::{
api::{ api::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
@ -16,6 +15,10 @@ use atuin_common::{
}, },
record::RecordIndex, record::RecordIndex,
}; };
use atuin_common::{
api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
record::{EncryptedData, HostId, Record, RecordId},
};
use semver::Version; use semver::Version;
use time::format_description::well_known::Rfc3339; use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime; use time::OffsetDateTime;
@ -52,10 +55,15 @@ pub async fn register(
let resp = client let resp = client
.post(url) .post(url)
.header(USER_AGENT, APP_USER_AGENT) .header(USER_AGENT, APP_USER_AGENT)
.header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION)
.json(&map) .json(&map)
.send() .send()
.await?; .await?;
if !ensure_version(&resp)? {
bail!("could not register user due to version mismatch");
}
if !resp.status().is_success() { if !resp.status().is_success() {
let error = resp.json::<ErrorResponse>().await?; let error = resp.json::<ErrorResponse>().await?;
bail!("failed to register user: {}", error.reason); bail!("failed to register user: {}", error.reason);
@ -76,6 +84,10 @@ pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> {
.send() .send()
.await?; .await?;
if !ensure_version(&resp)? {
bail!("could not login due to version mismatch");
}
if resp.status() != reqwest::StatusCode::OK { if resp.status() != reqwest::StatusCode::OK {
let error = resp.json::<ErrorResponse>().await?; let error = resp.json::<ErrorResponse>().await?;
bail!("invalid login details: {}", error.reason); bail!("invalid login details: {}", error.reason);
@ -106,6 +118,31 @@ pub async fn latest_version() -> Result<Version> {
Ok(version) Ok(version)
} }
pub fn ensure_version(response: &Response) -> Result<bool> {
let version = response.headers().get(ATUIN_HEADER_VERSION);
let version = if let Some(version) = version {
match version.to_str() {
Ok(v) => Version::parse(v),
Err(e) => bail!("failed to parse server version: {:?}", e),
}
} else {
// if there is no version header, then the newest this server can possibly be is 17.1.0
Version::parse("17.1.0")
}?;
// If the client is newer than the server
if version.major < ATUIN_VERSION.major {
println!("Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin");
println!("Client: {}", ATUIN_CARGO_VERSION);
println!("Server: {}", version);
return Ok(false);
}
Ok(true)
}
impl<'a> Client<'a> { impl<'a> Client<'a> {
pub fn new( pub fn new(
sync_addr: &'a str, sync_addr: &'a str,
@ -116,6 +153,9 @@ impl<'a> Client<'a> {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?); headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?);
// used for semver server check
headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?);
Ok(Client { Ok(Client {
sync_addr, sync_addr,
client: reqwest::Client::builder() client: reqwest::Client::builder()
@ -133,6 +173,10 @@ impl<'a> Client<'a> {
let resp = self.client.get(url).send().await?; let resp = self.client.get(url).send().await?;
if !ensure_version(&resp)? {
bail!("could not sync due to version mismatch");
}
if resp.status() != StatusCode::OK { if resp.status() != StatusCode::OK {
bail!("failed to get count (are you logged in?)"); bail!("failed to get count (are you logged in?)");
} }
@ -148,6 +192,10 @@ impl<'a> Client<'a> {
let resp = self.client.get(url).send().await?; let resp = self.client.get(url).send().await?;
if !ensure_version(&resp)? {
bail!("could not sync due to version mismatch");
}
if resp.status() != StatusCode::OK { if resp.status() != StatusCode::OK {
bail!("failed to get status (are you logged in?)"); bail!("failed to get status (are you logged in?)");
} }
@ -262,6 +310,11 @@ impl<'a> Client<'a> {
let url = Url::parse(url.as_str())?; let url = Url::parse(url.as_str())?;
let resp = self.client.get(url).send().await?; let resp = self.client.get(url).send().await?;
if !ensure_version(&resp)? {
bail!("could not sync records due to version mismatch");
}
let index = resp.json().await?; let index = resp.json().await?;
Ok(index) Ok(index)

View File

@ -20,6 +20,9 @@ rand = { workspace = true }
typed-builder = { workspace = true } typed-builder = { workspace = true }
eyre = { workspace = true } eyre = { workspace = true }
sqlx = { workspace = true } sqlx = { workspace = true }
semver = { workspace = true }
lazy_static = "1.4.0"
[dev-dependencies] [dev-dependencies]
pretty_assertions = { workspace = true } pretty_assertions = { workspace = true }

View File

@ -1,7 +1,18 @@
use lazy_static::lazy_static;
use semver::Version;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::Cow; use std::borrow::Cow;
use time::OffsetDateTime; use time::OffsetDateTime;
// the usage of X- has been deprecated for quite along time, it turns out
pub static ATUIN_HEADER_VERSION: &str = "Atuin-Version";
pub static ATUIN_CARGO_VERSION: &str = env!("CARGO_PKG_VERSION");
lazy_static! {
pub static ref ATUIN_VERSION: Version =
Version::parse(ATUIN_CARGO_VERSION).expect("failed to parse self semver");
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct UserResponse { pub struct UserResponse {
pub username: String, pub username: String,

View File

@ -1,5 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::api::ErrorResponse; use atuin_common::api::{ErrorResponse, ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION};
use axum::{ use axum::{
extract::FromRequestParts, extract::FromRequestParts,
http::Request, http::Request,
@ -91,6 +91,16 @@ async fn clacks_overhead<B>(request: Request<B>, next: Next<B>) -> Response {
response response
} }
/// Ensure that we only try and sync with clients on the same major version
async fn semver<B>(request: Request<B>, next: Next<B>) -> Response {
let mut response = next.run(request).await;
response
.headers_mut()
.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap());
response
}
#[derive(Clone)] #[derive(Clone)]
pub struct AppState<DB: Database> { pub struct AppState<DB: Database> {
pub database: DB, pub database: DB,
@ -126,6 +136,7 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
ServiceBuilder::new() ServiceBuilder::new()
.layer(axum::middleware::from_fn(clacks_overhead)) .layer(axum::middleware::from_fn(clacks_overhead))
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer(axum::middleware::from_fn(metrics::track_metrics)), .layer(axum::middleware::from_fn(metrics::track_metrics))
.layer(axum::middleware::from_fn(semver)),
) )
} }