Add index handler, use UUIDs not strings

This commit is contained in:
Ellie Huxtable 2023-06-26 21:18:14 +01:00
parent fd51031bde
commit 1774ec93eb
18 changed files with 195 additions and 79 deletions

3
Cargo.lock generated
View File

@ -240,6 +240,7 @@ dependencies = [
"serde",
"sqlx",
"tracing",
"uuid",
]
[[package]]
@ -2549,6 +2550,7 @@ dependencies = [
"thiserror",
"tokio-stream",
"url",
"uuid",
"webpki-roots",
"whoami",
]
@ -3019,6 +3021,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81"
dependencies = [
"getrandom 0.2.7",
"serde",
]
[[package]]

View File

@ -1,11 +1,11 @@
[workspace]
members = [
"atuin",
"atuin-client",
"atuin-server",
"atuin-server-postgres",
"atuin-server-database",
"atuin-common",
"atuin",
"atuin-client",
"atuin-server",
"atuin-server-postgres",
"atuin-server-database",
"atuin-common",
]
[workspace.package]
@ -35,7 +35,7 @@ semver = "1.0.14"
serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.86"
tokio = { version = "1", features = ["full"] }
uuid = { version = "1.3", features = ["v4"] }
uuid = { version = "1.3", features = ["v4", "serde"] }
whoami = "1.1.2"
typed-builder = "0.14.0"
@ -46,4 +46,4 @@ default-features = false
[workspace.dependencies.sqlx]
version = "0.6"
features = ["runtime-tokio-rustls", "chrono", "postgres"]
features = ["runtime-tokio-rustls", "chrono", "postgres", "uuid"]

View File

@ -12,6 +12,7 @@ use atuin_common::api::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
};
use atuin_common::record::Record;
use semver::Version;
use crate::{history::History, sync::hash_str};
@ -195,6 +196,15 @@ impl<'a> Client<'a> {
Ok(())
}
pub async fn post_records(&self, records: &[Record]) -> Result<()> {
let url = format!("{}/record", self.sync_addr);
let url = Url::parse(url.as_str())?;
self.client.post(url).json(records).send().await?;
Ok(())
}
pub async fn delete(&self) -> Result<()> {
let url = format!("{}/account", self.sync_addr);
let url = Url::parse(url.as_str())?;

View File

@ -13,6 +13,7 @@ use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Result, Row,
};
use uuid::Uuid;
use super::{
history::History,
@ -57,7 +58,7 @@ pub fn current_context() -> Context {
session,
hostname,
cwd,
host_id,
host_id: host_id.as_simple().to_string(),
}
}

View File

@ -101,10 +101,7 @@ impl KvStore {
let bytes = record.serialize()?;
let parent = store
.last(host_id.as_str(), KV_TAG)
.await?
.map(|entry| entry.id);
let parent = store.last(host_id, KV_TAG).await?.map(|entry| entry.id);
let record = atuin_common::record::Record::builder()
.host(host_id)
@ -138,7 +135,7 @@ impl KvStore {
// iterate records to find the value we want
// start at the end, so we get the most recent version
let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else {
let Some(mut record) = store.last(host_id, KV_TAG).await? else {
return Ok(None);
};

View File

@ -14,6 +14,7 @@ use sqlx::{
};
use atuin_common::record::{EncryptedData, Record};
use uuid::Uuid;
use super::store::Store;
@ -62,11 +63,11 @@ impl SqliteStore {
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
.bind(r.id.as_str())
.bind(r.host.as_str())
.bind(r.id.as_simple().to_string())
.bind(r.host.as_simple().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
.bind(r.parent.as_ref())
.bind(r.parent.map(|p| p.as_simple().to_string()))
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
@ -79,10 +80,21 @@ impl SqliteStore {
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
let timestamp: i64 = row.get("timestamp");
// tbh at this point things are pretty fucked so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
let parent: Option<&str> = row.get("host");
let parent = if let Some(parent) = parent {
Some(Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"))
} else {
None
};
Record {
id: row.get("id"),
host: row.get("host"),
parent: row.get("parent"),
id,
host,
parent,
timestamp: timestamp as u64,
tag: row.get("tag"),
version: row.get("version"),
@ -111,7 +123,7 @@ impl Store for SqliteStore {
Ok(())
}
async fn get(&self, id: &str) -> Result<Record<EncryptedData>> {
async fn get(&self, id: Uuid) -> Result<Record<EncryptedData>> {
let res = sqlx::query("select * from records where id = ?1")
.bind(id)
.map(Self::query_row)
@ -121,7 +133,7 @@ impl Store for SqliteStore {
Ok(res)
}
async fn len(&self, host: &str, tag: &str) -> Result<u64> {
async fn len(&self, host: Uuid, tag: &str) -> Result<u64> {
let res: (i64,) =
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
.bind(host)
@ -146,7 +158,7 @@ impl Store for SqliteStore {
}
}
async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
async fn first(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
)
@ -159,12 +171,12 @@ impl Store for SqliteStore {
Ok(res)
}
async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
async fn last(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
)
.bind(tag)
.bind(host)
.bind(host.as_simple().to_string())
.map(Self::query_row)
.fetch_optional(&self.pool)
.await?;
@ -183,7 +195,7 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
.host(atuin_common::utils::uuid_v7().simple().to_string())
.host(atuin_common::utils::uuid_v7())
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
@ -218,10 +230,7 @@ mod tests {
let record = test_record();
db.push(&record).await.unwrap();
let new_record = db
.get(record.id.as_str())
.await
.expect("failed to fetch record");
let new_record = db.get(record.id).await.expect("failed to fetch record");
assert_eq!(record, new_record, "records are not equal");
}
@ -233,7 +242,7 @@ mod tests {
db.push(&record).await.unwrap();
let len = db
.len(record.host.as_str(), record.tag.as_str())
.len(record.host, record.tag.as_str())
.await
.expect("failed to get store len");
@ -253,14 +262,8 @@ mod tests {
db.push(&first).await.unwrap();
db.push(&second).await.unwrap();
let first_len = db
.len(first.host.as_str(), first.tag.as_str())
.await
.unwrap();
let second_len = db
.len(second.host.as_str(), second.tag.as_str())
.await
.unwrap();
let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
@ -281,7 +284,7 @@ mod tests {
}
assert_eq!(
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
@ -304,7 +307,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
10000,
"failed to insert 10k records"
);
@ -327,7 +330,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap();
let mut record = db
.first(tail.host.as_str(), tail.tag.as_str())
.first(tail.host, tail.tag.as_str())
.await
.expect("in memory sqlite should not fail")
.expect("entry exists");

View File

@ -2,6 +2,7 @@ use async_trait::async_trait;
use eyre::Result;
use atuin_common::record::{EncryptedData, Record};
use uuid::Uuid;
/// A record store stores records
/// In more detail - we tend to need to process this into _another_ format to actually query it.
@ -21,13 +22,13 @@ pub trait Store {
) -> Result<()>;
async fn get(&self, id: &str) -> Result<Record<EncryptedData>>;
async fn len(&self, host: &str, tag: &str) -> Result<u64>;
async fn len(&self, host: Uuid, tag: &str) -> Result<u64>;
/// Get the record that follows this record
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
/// Get the first record for a given host and tag
async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn first(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>;
/// Get the last record for a given host and tag
async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn last(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>;
}

View File

@ -1,6 +1,7 @@
use std::{
io::prelude::*,
path::{Path, PathBuf},
str::FromStr,
};
use chrono::{prelude::*, Utc};
@ -12,6 +13,7 @@ use parse_duration::parse;
use regex::RegexSet;
use semver::Version;
use serde::Deserialize;
use uuid::Uuid;
pub const HISTORY_PAGE_SIZE: i64 = 100;
pub const LAST_SYNC_FILENAME: &str = "last_sync_time";
@ -228,11 +230,13 @@ impl Settings {
Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME)
}
pub fn host_id() -> Option<String> {
pub fn host_id() -> Option<Uuid> {
let id = Settings::read_from_data_dir(HOST_ID_FILENAME);
if id.is_some() {
return id;
let parsed = Uuid::from_str(id.unwrap().as_str())
.expect("failed to parse host ID from local directory");
return Some(parsed);
}
let uuid = atuin_common::utils::uuid_v7();
@ -240,7 +244,7 @@ impl Settings {
Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref())
.expect("Could not write host ID to data dir");
Some(uuid.as_simple().to_string())
Some(uuid)
}
pub fn should_sync(&self) -> Result<bool> {

View File

@ -3,6 +3,7 @@ use std::collections::HashMap;
use eyre::Result;
use serde::{Deserialize, Serialize};
use typed_builder::TypedBuilder;
use uuid::Uuid;
#[derive(Clone, Debug, PartialEq)]
pub struct DecryptedData(pub Vec<u8>);
@ -17,21 +18,21 @@ pub struct EncryptedData {
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
pub struct Record<Data> {
/// a unique ID
#[builder(default = crate::utils::uuid_v7().as_simple().to_string())]
pub id: String,
#[builder(default = crate::utils::uuid_v7())]
pub id: Uuid,
/// The unique ID of the host.
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
// as strings. I would rather avoid normalization, so store as UUID binary instead of
// encoding to a string and wasting much more storage.
pub host: String,
pub host: Uuid,
/// The ID of the parent entry
// A store is technically just a double linked list
// We can do some cheating with the timestamps, but should not rely upon them.
// Clocks are tricksy.
#[builder(default)]
pub parent: Option<String>,
pub parent: Option<Uuid>,
/// The creation time in nanoseconds since unix epoch
#[builder(default = chrono::Utc::now().timestamp_nanos() as u64)]
@ -71,9 +72,10 @@ impl<Data> Record<Data> {
/// An index representing the current state of the record stores
/// This can be both remote, or local, and compared in either direction
#[derive(Debug, Serialize, Deserialize)]
pub struct RecordIndex {
// A map of host -> tag -> tail
pub hosts: HashMap<String, HashMap<String, String>>,
pub hosts: HashMap<Uuid, HashMap<String, Uuid>>,
}
impl Default for RecordIndex {
@ -97,7 +99,11 @@ impl RecordIndex {
.insert(tail.tag, tail.id);
}
pub fn get(&self, host: String, tag: String) -> Option<String> {
pub fn set_raw(&mut self, host: Uuid, tag: String, tail: Uuid) {
self.hosts.entry(host).or_default().insert(tag, tail);
}
pub fn get(&self, host: Uuid, tag: String) -> Option<Uuid> {
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
}
@ -108,7 +114,7 @@ impl RecordIndex {
/// other machine has a different tail, it will be the differing tail. This is useful to
/// check if the other index is ahead of us, or behind.
/// If the other index does not have the (host, tag) pair, then the other value will be None.
pub fn diff(&self, other: &Self) -> Vec<(String, String, Option<String>)> {
pub fn diff(&self, other: &Self) -> Vec<(Uuid, String, Option<Uuid>)> {
let mut ret = Vec::new();
// First, we check if other has everything that self has
@ -227,10 +233,11 @@ impl Record<EncryptedData> {
mod tests {
use super::{DecryptedData, Record, RecordIndex};
use pretty_assertions::assert_eq;
use uuid::Uuid;
fn test_record() -> Record<DecryptedData> {
Record::builder()
.host(crate::utils::uuid_v7().simple().to_string())
.host(crate::utils::uuid_v7())
.version("v1".into())
.tag(crate::utils::uuid_v7().simple().to_string())
.data(DecryptedData(vec![0, 1, 2, 3]))
@ -344,9 +351,9 @@ mod tests {
// both diffs should be ALMOST the same. They will agree on which hosts and tags
// require updating, but the "other" value will not be the same.
let smol_diff_1: Vec<(String, String)> =
let smol_diff_1: Vec<(Uuid, String)> =
diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();
let smol_diff_2: Vec<(String, String)> =
let smol_diff_2: Vec<(Uuid, String)> =
diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();
assert_eq!(smol_diff_1, smol_diff_2);

View File

@ -18,6 +18,7 @@ use chrono::{Datelike, TimeZone};
use chronoutil::RelativeDuration;
use serde::{de::DeserializeOwned, Serialize};
use tracing::instrument;
use uuid::Uuid;
#[derive(Debug)]
pub enum DbError {
@ -55,10 +56,10 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
async fn add_record(&self, user: &User, record: &[Record]) -> DbResult<()>;
async fn add_records(&self, user: &User, record: &[Record]) -> DbResult<()>;
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
async fn tail_records(&self, user: &User) -> DbResult<Vec<(String, String, String)>>;
async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>>;
async fn count_history_range(
&self,

View File

@ -18,4 +18,5 @@ chrono = { workspace = true }
serde = { workspace = true }
sqlx = { workspace = true }
async-trait = { workspace = true }
uuid = { workspace = true }
futures-util = "0.3"

View File

@ -1,3 +1,5 @@
use std::str::FromStr;
use async_trait::async_trait;
use atuin_common::record::Record;
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
@ -5,9 +7,10 @@ use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt;
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgPoolOptions;
use sqlx::Row;
use sqlx::types::Uuid;
use tracing::instrument;
use wrappers::{DbHistory, DbSession, DbUser};
@ -331,11 +334,11 @@ impl Database for Postgres {
.map(|DbHistory(h)| h)
}
async fn add_record(&self, user: &User, records: &[Record]) -> DbResult<()> {
async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> {
let mut tx = self.pool.begin().await.map_err(fix_error)?;
for i in records {
let id = atuin_common::utils::uuid_v7().as_simple().to_string();
let id = atuin_common::utils::uuid_v7();
sqlx::query(
"insert into records
@ -345,9 +348,9 @@ impl Database for Postgres {
",
)
.bind(id)
.bind(&i.id)
.bind(&i.host)
.bind(&i.parent)
.bind(i.id)
.bind(i.host)
.bind(id)
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
.bind(&i.version)
.bind(&i.tag)
@ -363,8 +366,8 @@ impl Database for Postgres {
Ok(())
}
async fn tail_records(&self, user: &User) -> DbResult<Vec<(String, String, String)>> {
const TAIL_RECORDS_SQL: &str = "select host, tag, id from records rp where (select count(1) from records where parent=rp.id and user_id = $1) = 0 group by host, tag;";
async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>> {
const TAIL_RECORDS_SQL: &str = "select host, tag, id from records rp where (select count(1) from records where parent=rp.id and user_id = $1) = 0;";
let res = sqlx::query_as(TAIL_RECORDS_SQL)
.bind(user.id)

View File

@ -2,6 +2,7 @@ use atuin_common::api::{ErrorResponse, IndexResponse};
use axum::{response::IntoResponse, Json};
pub mod history;
pub mod record;
pub mod status;
pub mod user;

View File

@ -0,0 +1,70 @@
use axum::{extract::State, Json};
use http::StatusCode;
use tracing::{error, instrument};
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::router::{AppState, UserAuth};
use atuin_server_database::Database;
use atuin_common::record::{Record, RecordIndex};
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn post<DB: Database>(
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
Json(records): Json<Vec<Record>>,
) -> Result<(), ErrorResponseStatus<'static>> {
let State(AppState { database, settings }) = state;
tracing::debug!(
count = records.len(),
user = user.username,
"request to add records"
);
let too_big = records
.iter()
.any(|r| r.data.len() >= settings.max_record_size || settings.max_record_size == 0);
if too_big {
return Err(
ErrorResponse::reply("could not add records; record too large")
.with_status(StatusCode::BAD_REQUEST),
);
}
if let Err(e) = database.add_records(&user, &records).await {
error!("failed to add record: {}", e);
return Err(ErrorResponse::reply("failed to add record")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
};
Ok(())
}
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn index<DB: Database>(
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> {
let State(AppState { database, settings }) = state;
let index = match database.tail_records(&user).await {
Ok(index) => index,
Err(e) => {
error!("failed to get record index: {}", e);
return Err(ErrorResponse::reply("failed to calculate record index")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
}
};
let mut record_index = RecordIndex::new();
for row in index {
record_index.set_raw(row.0, row.1, row.2);
}
Ok(Json(record_index))
}

View File

@ -71,6 +71,8 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
.route("/sync/status", get(handlers::status::status))
.route("/history", post(handlers::history::add))
.route("/history", delete(handlers::history::delete))
.route("/record", post(handlers::record::post))
.route("/record", get(handlers::record::index))
.route("/user/:username", get(handlers::user::get))
.route("/account", delete(handlers::user::delete))
.route("/register", post(handlers::user::register))

View File

@ -12,6 +12,7 @@ pub struct Settings<DbSettings> {
pub path: String,
pub open_registration: bool,
pub max_history_length: usize,
pub max_record_size: usize,
pub page_size: i64,
pub register_webhook_url: Option<String>,
pub register_webhook_username: String,
@ -39,6 +40,7 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
.set_default("port", 8888)?
.set_default("open_registration", false)?
.set_default("max_history_length", 8192)?
.set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky
.set_default("path", "")?
.set_default("register_webhook_username", "")?
.set_default("page_size", 1100)?

View File

@ -69,7 +69,7 @@ impl Cmd {
Self::Search(search) => search.run(db, &mut settings).await,
#[cfg(feature = "sync")]
Self::Sync(sync) => sync.run(settings, &mut db).await,
Self::Sync(sync) => sync.run(settings, &mut db, &mut store).await,
#[cfg(feature = "sync")]
Self::Account(account) => account.run(settings).await,

View File

@ -1,7 +1,7 @@
use clap::Subcommand;
use eyre::{Result, WrapErr};
use atuin_client::{database::Database, settings::Settings};
use atuin_client::{api_client, database::Database, record::store::Store, settings::Settings};
mod status;
@ -37,9 +37,14 @@ pub enum Cmd {
}
impl Cmd {
pub async fn run(self, settings: Settings, db: &mut impl Database) -> Result<()> {
pub async fn run(
self,
settings: Settings,
db: &mut impl Database,
store: &mut impl Store,
) -> Result<()> {
match self {
Self::Sync { force } => run(&settings, force, db).await,
Self::Sync { force } => run(&settings, force, db, store).await,
Self::Login(l) => l.run(&settings).await,
Self::Logout => account::logout::run(&settings),
Self::Register(r) => r.run(&settings).await,
@ -62,12 +67,17 @@ impl Cmd {
}
}
async fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
atuin_client::sync::sync(settings, force, db).await?;
println!(
"Sync complete! {} items in database, force: {}",
db.history_count().await?,
force
);
async fn run(
settings: &Settings,
force: bool,
db: &mut impl Database,
store: &mut impl Store,
) -> Result<()> {
let host = Settings::host_id().expect("No host ID found");
// FOR TESTING ONLY!
let kv_tail = store.last(host, "kv").await?.expect("no kv found");
let client = api_client::Client::new(&settings.sync_address, &settings.session_token)?;
client.post_records(&[kv_tail]).await?;
Ok(())
}