Add index handler, use UUIDs not strings

Ellie Huxtable
2023-06-26 21:18:14 +01:00
parent fd51031bde
commit 1774ec93eb
18 changed files with 195 additions and 79 deletions

Cargo.lock (generated)
View File

@@ -240,6 +240,7 @@ dependencies = [
  "serde",
  "sqlx",
  "tracing",
+ "uuid",
 ]

 [[package]]
@@ -2549,6 +2550,7 @@ dependencies = [
  "thiserror",
  "tokio-stream",
  "url",
+ "uuid",
  "webpki-roots",
  "whoami",
 ]
@@ -3019,6 +3021,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81"
 dependencies = [
  "getrandom 0.2.7",
+ "serde",
 ]

 [[package]]

View File

@@ -1,11 +1,11 @@
 [workspace]
 members = [
     "atuin",
     "atuin-client",
     "atuin-server",
     "atuin-server-postgres",
     "atuin-server-database",
     "atuin-common",
 ]

 [workspace.package]
@@ -35,7 +35,7 @@ semver = "1.0.14"
 serde = { version = "1.0.145", features = ["derive"] }
 serde_json = "1.0.86"
 tokio = { version = "1", features = ["full"] }
-uuid = { version = "1.3", features = ["v4"] }
+uuid = { version = "1.3", features = ["v4", "serde"] }
 whoami = "1.1.2"
 typed-builder = "0.14.0"
@@ -46,4 +46,4 @@ default-features = false

 [workspace.dependencies.sqlx]
 version = "0.6"
-features = ["runtime-tokio-rustls", "chrono", "postgres"]
+features = ["runtime-tokio-rustls", "chrono", "postgres", "uuid"]
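
Note: the two feature additions above work together. uuid's "serde" feature lets the Uuid fields introduced in this commit pass through the JSON API, and sqlx's "uuid" feature lets them bind and decode directly in queries. A minimal standalone sketch of the serde side, not part of the commit (the Tail struct is illustrative only; serde_json is already a workspace dependency above):

use serde::{Deserialize, Serialize};
use uuid::Uuid;

// Illustrative struct, not a type from this commit.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Tail {
    host: Uuid,
    tag: String,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tail = Tail {
        host: Uuid::new_v4(),
        tag: "kv".into(),
    };

    // With the "serde" feature, Uuid serializes as its hyphenated string form.
    let json = serde_json::to_string(&tail)?;
    let back: Tail = serde_json::from_str(&json)?;
    assert_eq!(tail, back);

    Ok(())
}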

View File

@@ -12,6 +12,7 @@ use atuin_common::api::{
     AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
     LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
 };
+use atuin_common::record::Record;
 use semver::Version;

 use crate::{history::History, sync::hash_str};
@@ -195,6 +196,15 @@ impl<'a> Client<'a> {
         Ok(())
     }

+    pub async fn post_records(&self, records: &[Record]) -> Result<()> {
+        let url = format!("{}/record", self.sync_addr);
+        let url = Url::parse(url.as_str())?;
+
+        self.client.post(url).json(records).send().await?;
+
+        Ok(())
+    }
+
     pub async fn delete(&self) -> Result<()> {
         let url = format!("{}/account", self.sync_addr);
         let url = Url::parse(url.as_str())?;

View File

@@ -13,6 +13,7 @@ use sqlx::{
     sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
     Result, Row,
 };
+use uuid::Uuid;

 use super::{
     history::History,
@@ -57,7 +58,7 @@ pub fn current_context() -> Context {
         session,
         hostname,
         cwd,
-        host_id,
+        host_id: host_id.as_simple().to_string(),
     }
 }
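
Note: Settings::host_id() now returns a Uuid (see the settings.rs diff further down), so call sites that still need a string, like this Context, format it with as_simple(). A standalone sketch of the two textual forms and how both parse back, not from the commit:

use std::str::FromStr;
use uuid::Uuid;

fn main() {
    let id = Uuid::new_v4();

    let simple = id.as_simple().to_string(); // 32 hex chars, no hyphens: the stored form
    let hyphenated = id.to_string();         // canonical 36-char form with hyphens

    // Uuid::from_str / Uuid::parse_str accept both forms, so reads stay compatible.
    assert_eq!(Uuid::from_str(&simple).unwrap(), id);
    assert_eq!(Uuid::parse_str(&hyphenated).unwrap(), id);
}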

View File

@@ -101,10 +101,7 @@ impl KvStore {
         let bytes = record.serialize()?;

-        let parent = store
-            .last(host_id.as_str(), KV_TAG)
-            .await?
-            .map(|entry| entry.id);
+        let parent = store.last(host_id, KV_TAG).await?.map(|entry| entry.id);

         let record = atuin_common::record::Record::builder()
             .host(host_id)
@@ -138,7 +135,7 @@ impl KvStore {
         // iterate records to find the value we want
         // start at the end, so we get the most recent version
-        let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else {
+        let Some(mut record) = store.last(host_id, KV_TAG).await? else {
             return Ok(None);
         };

View File

@@ -14,6 +14,7 @@ use sqlx::{
 };

 use atuin_common::record::{EncryptedData, Record};
+use uuid::Uuid;

 use super::store::Store;
@@ -62,11 +63,11 @@ impl SqliteStore {
             "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
                 values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
         )
-        .bind(r.id.as_str())
-        .bind(r.host.as_str())
+        .bind(r.id.as_simple().to_string())
+        .bind(r.host.as_simple().to_string())
         .bind(r.tag.as_str())
         .bind(r.timestamp as i64)
-        .bind(r.parent.as_ref())
+        .bind(r.parent.map(|p| p.as_simple().to_string()))
         .bind(r.version.as_str())
         .bind(r.data.data.as_str())
         .bind(r.data.content_encryption_key.as_str())
@@ -79,10 +80,21 @@ impl SqliteStore {
     fn query_row(row: SqliteRow) -> Record<EncryptedData> {
         let timestamp: i64 = row.get("timestamp");

+        // tbh at this point things are pretty fucked so just panic
+        let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
+        let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
+        let parent: Option<&str> = row.get("parent");
+
+        let parent = if let Some(parent) = parent {
+            Some(Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"))
+        } else {
+            None
+        };
+
         Record {
-            id: row.get("id"),
-            host: row.get("host"),
-            parent: row.get("parent"),
+            id,
+            host,
+            parent,
             timestamp: timestamp as u64,
             tag: row.get("tag"),
             version: row.get("version"),
@@ -111,7 +123,7 @@ impl Store for SqliteStore {
         Ok(())
     }

-    async fn get(&self, id: &str) -> Result<Record<EncryptedData>> {
+    async fn get(&self, id: Uuid) -> Result<Record<EncryptedData>> {
         let res = sqlx::query("select * from records where id = ?1")
             .bind(id)
             .map(Self::query_row)
@@ -121,7 +133,7 @@
         Ok(res)
     }

-    async fn len(&self, host: &str, tag: &str) -> Result<u64> {
+    async fn len(&self, host: Uuid, tag: &str) -> Result<u64> {
         let res: (i64,) =
             sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
                 .bind(host)
@@ -146,7 +158,7 @@
         }
     }

-    async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+    async fn first(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>> {
         let res = sqlx::query(
             "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
         )
@@ -159,12 +171,12 @@
         Ok(res)
     }

-    async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
+    async fn last(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>> {
         let res = sqlx::query(
             "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
         )
         .bind(tag)
-        .bind(host)
+        .bind(host.as_simple().to_string())
         .map(Self::query_row)
         .fetch_optional(&self.pool)
         .await?;
@@ -183,7 +195,7 @@ mod tests {
     fn test_record() -> Record<EncryptedData> {
         Record::builder()
-            .host(atuin_common::utils::uuid_v7().simple().to_string())
+            .host(atuin_common::utils::uuid_v7())
             .version("v1".into())
             .tag(atuin_common::utils::uuid_v7().simple().to_string())
             .data(EncryptedData {
@@ -218,10 +230,7 @@
         let record = test_record();
         db.push(&record).await.unwrap();

-        let new_record = db
-            .get(record.id.as_str())
-            .await
-            .expect("failed to fetch record");
+        let new_record = db.get(record.id).await.expect("failed to fetch record");

         assert_eq!(record, new_record, "records are not equal");
     }
@@ -233,7 +242,7 @@
         db.push(&record).await.unwrap();

         let len = db
-            .len(record.host.as_str(), record.tag.as_str())
+            .len(record.host, record.tag.as_str())
             .await
             .expect("failed to get store len");
@@ -253,14 +262,8 @@
         db.push(&first).await.unwrap();
         db.push(&second).await.unwrap();

-        let first_len = db
-            .len(first.host.as_str(), first.tag.as_str())
-            .await
-            .unwrap();
-        let second_len = db
-            .len(second.host.as_str(), second.tag.as_str())
-            .await
-            .unwrap();
+        let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
+        let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();

         assert_eq!(first_len, 1, "expected length of 1 after insert");
         assert_eq!(second_len, 1, "expected length of 1 after insert");
@@ -281,7 +284,7 @@
         }

         assert_eq!(
-            db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+            db.len(tail.host, tail.tag.as_str()).await.unwrap(),
             100,
             "failed to insert 100 records"
         );
@@ -304,7 +307,7 @@
         db.push_batch(records.iter()).await.unwrap();

         assert_eq!(
-            db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
+            db.len(tail.host, tail.tag.as_str()).await.unwrap(),
             10000,
             "failed to insert 10k records"
         );
@@ -327,7 +330,7 @@
         db.push_batch(records.iter()).await.unwrap();

         let mut record = db
-            .first(tail.host.as_str(), tail.tag.as_str())
+            .first(tail.host, tail.tag.as_str())
             .await
             .expect("in memory sqlite should not fail")
             .expect("entry exists");

View File

@@ -2,6 +2,7 @@ use async_trait::async_trait;
 use eyre::Result;

 use atuin_common::record::{EncryptedData, Record};
+use uuid::Uuid;

 /// A record store stores records
 /// In more detail - we tend to need to process this into _another_ format to actually query it.
@@ -21,13 +22,13 @@ pub trait Store {
     ) -> Result<()>;

-    async fn get(&self, id: &str) -> Result<Record<EncryptedData>>;
-    async fn len(&self, host: &str, tag: &str) -> Result<u64>;
+    async fn get(&self, id: Uuid) -> Result<Record<EncryptedData>>;
+    async fn len(&self, host: Uuid, tag: &str) -> Result<u64>;

     /// Get the record that follows this record
     async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;

     /// Get the first record for a given host and tag
-    async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+    async fn first(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>;
     /// Get the last record for a given host and tag
-    async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
+    async fn last(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>;
 }

View File

@@ -1,6 +1,7 @@
 use std::{
     io::prelude::*,
     path::{Path, PathBuf},
+    str::FromStr,
 };

 use chrono::{prelude::*, Utc};
@@ -12,6 +13,7 @@ use parse_duration::parse;
 use regex::RegexSet;
 use semver::Version;
 use serde::Deserialize;
+use uuid::Uuid;

 pub const HISTORY_PAGE_SIZE: i64 = 100;
 pub const LAST_SYNC_FILENAME: &str = "last_sync_time";
@@ -228,11 +230,13 @@
         Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME)
     }

-    pub fn host_id() -> Option<String> {
+    pub fn host_id() -> Option<Uuid> {
         let id = Settings::read_from_data_dir(HOST_ID_FILENAME);

         if id.is_some() {
-            return id;
+            let parsed = Uuid::from_str(id.unwrap().as_str())
+                .expect("failed to parse host ID from local directory");
+            return Some(parsed);
         }

         let uuid = atuin_common::utils::uuid_v7();
@@ -240,7 +244,7 @@
         Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref())
             .expect("Could not write host ID to data dir");

-        Some(uuid.as_simple().to_string())
+        Some(uuid)
     }

     pub fn should_sync(&self) -> Result<bool> {

View File

@@ -3,6 +3,7 @@ use std::collections::HashMap;
 use eyre::Result;
 use serde::{Deserialize, Serialize};
 use typed_builder::TypedBuilder;
+use uuid::Uuid;

 #[derive(Clone, Debug, PartialEq)]
 pub struct DecryptedData(pub Vec<u8>);
@@ -17,21 +18,21 @@ pub struct EncryptedData {
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
 pub struct Record<Data> {
     /// a unique ID
-    #[builder(default = crate::utils::uuid_v7().as_simple().to_string())]
-    pub id: String,
+    #[builder(default = crate::utils::uuid_v7())]
+    pub id: Uuid,

     /// The unique ID of the host.
     // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
     // as strings. I would rather avoid normalization, so store as UUID binary instead of
     // encoding to a string and wasting much more storage.
-    pub host: String,
+    pub host: Uuid,

     /// The ID of the parent entry
     // A store is technically just a double linked list
     // We can do some cheating with the timestamps, but should not rely upon them.
     // Clocks are tricksy.
     #[builder(default)]
-    pub parent: Option<String>,
+    pub parent: Option<Uuid>,

     /// The creation time in nanoseconds since unix epoch
     #[builder(default = chrono::Utc::now().timestamp_nanos() as u64)]
@@ -71,9 +72,10 @@ impl<Data> Record<Data> {
 /// An index representing the current state of the record stores
 /// This can be both remote, or local, and compared in either direction
+#[derive(Debug, Serialize, Deserialize)]
 pub struct RecordIndex {
     // A map of host -> tag -> tail
-    pub hosts: HashMap<String, HashMap<String, String>>,
+    pub hosts: HashMap<Uuid, HashMap<String, Uuid>>,
 }

 impl Default for RecordIndex {
@@ -97,7 +99,11 @@ impl RecordIndex {
             .insert(tail.tag, tail.id);
     }

-    pub fn get(&self, host: String, tag: String) -> Option<String> {
+    pub fn set_raw(&mut self, host: Uuid, tag: String, tail: Uuid) {
+        self.hosts.entry(host).or_default().insert(tag, tail);
+    }
+
+    pub fn get(&self, host: Uuid, tag: String) -> Option<Uuid> {
         self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
     }
@@ -108,7 +114,7 @@
     /// other machine has a different tail, it will be the differing tail. This is useful to
     /// check if the other index is ahead of us, or behind.
     /// If the other index does not have the (host, tag) pair, then the other value will be None.
-    pub fn diff(&self, other: &Self) -> Vec<(String, String, Option<String>)> {
+    pub fn diff(&self, other: &Self) -> Vec<(Uuid, String, Option<Uuid>)> {
         let mut ret = Vec::new();

         // First, we check if other has everything that self has
@@ -227,10 +233,11 @@ impl Record<EncryptedData> {
 mod tests {
     use super::{DecryptedData, Record, RecordIndex};
     use pretty_assertions::assert_eq;
+    use uuid::Uuid;

     fn test_record() -> Record<DecryptedData> {
         Record::builder()
-            .host(crate::utils::uuid_v7().simple().to_string())
+            .host(crate::utils::uuid_v7())
             .version("v1".into())
             .tag(crate::utils::uuid_v7().simple().to_string())
             .data(DecryptedData(vec![0, 1, 2, 3]))
@@ -344,9 +351,9 @@
         // both diffs should be ALMOST the same. They will agree on which hosts and tags
         // require updating, but the "other" value will not be the same.

-        let smol_diff_1: Vec<(String, String)> =
+        let smol_diff_1: Vec<(Uuid, String)> =
             diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();
-        let smol_diff_2: Vec<(String, String)> =
+        let smol_diff_2: Vec<(Uuid, String)> =
             diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();

         assert_eq!(smol_diff_1, smol_diff_2);
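
Note: the RecordIndex changes above (the Serialize/Deserialize derive plus Uuid keys) are what allow the new GET /record handler to return the index as JSON. A standalone, simplified mirror of the host -> tag -> tail shape, not the crate's type; set_raw and get are re-sketched here only to show the map operations:

use std::collections::HashMap;
use uuid::Uuid;

// Simplified stand-in for RecordIndex, keyed by Uuid instead of String.
#[derive(Debug, Default)]
struct Index {
    hosts: HashMap<Uuid, HashMap<String, Uuid>>,
}

impl Index {
    fn set_raw(&mut self, host: Uuid, tag: String, tail: Uuid) {
        self.hosts.entry(host).or_default().insert(tag, tail);
    }

    fn get(&self, host: Uuid, tag: &str) -> Option<Uuid> {
        self.hosts.get(&host).and_then(|tags| tags.get(tag)).copied()
    }
}

fn main() {
    let (host, tail) = (Uuid::new_v4(), Uuid::new_v4());
    let mut index = Index::default();

    index.set_raw(host, "kv".to_string(), tail);
    assert_eq!(index.get(host, "kv"), Some(tail));
    assert_eq!(index.get(Uuid::new_v4(), "kv"), None);
}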

View File

@@ -18,6 +18,7 @@ use chrono::{Datelike, TimeZone};
 use chronoutil::RelativeDuration;
 use serde::{de::DeserializeOwned, Serialize};
 use tracing::instrument;
+use uuid::Uuid;

 #[derive(Debug)]
 pub enum DbError {
@@ -55,10 +56,10 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
     async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
     async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;

-    async fn add_record(&self, user: &User, record: &[Record]) -> DbResult<()>;
+    async fn add_records(&self, user: &User, record: &[Record]) -> DbResult<()>;

     // Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
-    async fn tail_records(&self, user: &User) -> DbResult<Vec<(String, String, String)>>;
+    async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>>;

     async fn count_history_range(
         &self,

View File

@@ -18,4 +18,5 @@ chrono = { workspace = true }
 serde = { workspace = true }
 sqlx = { workspace = true }
 async-trait = { workspace = true }
+uuid = { workspace = true }
 futures-util = "0.3"

View File

@@ -1,3 +1,5 @@
use std::str::FromStr;
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::record::Record; use atuin_common::record::Record;
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
@@ -5,9 +7,10 @@ use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt; use futures_util::TryStreamExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::postgres::PgPoolOptions; use sqlx::postgres::PgPoolOptions;
use sqlx::Row; use sqlx::Row;
use sqlx::types::Uuid;
use tracing::instrument; use tracing::instrument;
use wrappers::{DbHistory, DbSession, DbUser}; use wrappers::{DbHistory, DbSession, DbUser};
@@ -331,11 +334,11 @@ impl Database for Postgres {
         .map(|DbHistory(h)| h)
     }

-    async fn add_record(&self, user: &User, records: &[Record]) -> DbResult<()> {
+    async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> {
         let mut tx = self.pool.begin().await.map_err(fix_error)?;

         for i in records {
-            let id = atuin_common::utils::uuid_v7().as_simple().to_string();
+            let id = atuin_common::utils::uuid_v7();

             sqlx::query(
                 "insert into records
@@ -345,9 +348,9 @@
                 ",
             )
             .bind(id)
-            .bind(&i.id)
-            .bind(&i.host)
-            .bind(&i.parent)
+            .bind(i.id)
+            .bind(i.host)
+            .bind(i.parent)
             .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
             .bind(&i.version)
             .bind(&i.tag)
@@ -363,8 +366,8 @@
         Ok(())
     }

-    async fn tail_records(&self, user: &User) -> DbResult<Vec<(String, String, String)>> {
-        const TAIL_RECORDS_SQL: &str = "select host, tag, id from records rp where (select count(1) from records where parent=rp.id and user_id = $1) = 0 group by host, tag;";
+    async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>> {
+        const TAIL_RECORDS_SQL: &str = "select host, tag, id from records rp where (select count(1) from records where parent=rp.id and user_id = $1) = 0;";

         let res = sqlx::query_as(TAIL_RECORDS_SQL)
             .bind(user.id)
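
Note: with sqlx's "uuid" feature enabled in the workspace, tail rows can decode straight into (Uuid, String, Uuid) tuples, as the new tail_records signature expects. A standalone sketch, not from the commit; it assumes a reachable Postgres with a compatible records table, uses a placeholder connection string, and omits the per-user filter:

use sqlx::postgres::PgPoolOptions;
use sqlx::types::Uuid;

#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    // Placeholder connection string; point this at a real database to run it.
    let pool = PgPoolOptions::new()
        .connect("postgres://localhost/atuin")
        .await?;

    // Tail records: rows that no other record names as its parent.
    let tails: Vec<(Uuid, String, Uuid)> = sqlx::query_as(
        "select host, tag, id from records rp
         where (select count(1) from records where parent = rp.id) = 0",
    )
    .fetch_all(&pool)
    .await?;

    for (host, tag, id) in tails {
        println!("{host} {tag} {id}");
    }

    Ok(())
}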

View File

@@ -2,6 +2,7 @@ use atuin_common::api::{ErrorResponse, IndexResponse};
 use axum::{response::IntoResponse, Json};

 pub mod history;
+pub mod record;
 pub mod status;
 pub mod user;

View File

@@ -0,0 +1,70 @@
use axum::{extract::State, Json};
use http::StatusCode;
use tracing::{error, instrument};
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::router::{AppState, UserAuth};
use atuin_server_database::Database;
use atuin_common::record::{Record, RecordIndex};
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn post<DB: Database>(
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
Json(records): Json<Vec<Record>>,
) -> Result<(), ErrorResponseStatus<'static>> {
let State(AppState { database, settings }) = state;
tracing::debug!(
count = records.len(),
user = user.username,
"request to add records"
);
let too_big = records
.iter()
.any(|r| r.data.len() >= settings.max_record_size || settings.max_record_size == 0);
if too_big {
return Err(
ErrorResponse::reply("could not add records; record too large")
.with_status(StatusCode::BAD_REQUEST),
);
}
if let Err(e) = database.add_records(&user, &records).await {
error!("failed to add record: {}", e);
return Err(ErrorResponse::reply("failed to add record")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
};
Ok(())
}
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn index<DB: Database>(
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> {
let State(AppState { database, settings }) = state;
let index = match database.tail_records(&user).await {
Ok(index) => index,
Err(e) => {
error!("failed to get record index: {}", e);
return Err(ErrorResponse::reply("failed to calculate record index")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
}
};
let mut record_index = RecordIndex::new();
for row in index {
record_index.set_raw(row.0, row.1, row.2);
}
Ok(Json(record_index))
}

View File

@@ -71,6 +71,8 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
         .route("/sync/status", get(handlers::status::status))
         .route("/history", post(handlers::history::add))
        .route("/history", delete(handlers::history::delete))
+        .route("/record", post(handlers::record::post))
+        .route("/record", get(handlers::record::index))
         .route("/user/:username", get(handlers::user::get))
         .route("/account", delete(handlers::user::delete))
         .route("/register", post(handlers::user::register))

View File

@@ -12,6 +12,7 @@ pub struct Settings<DbSettings> {
     pub path: String,
     pub open_registration: bool,
     pub max_history_length: usize,
+    pub max_record_size: usize,
     pub page_size: i64,
     pub register_webhook_url: Option<String>,
     pub register_webhook_username: String,
@@ -39,6 +40,7 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
             .set_default("port", 8888)?
             .set_default("open_registration", false)?
             .set_default("max_history_length", 8192)?
+            .set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky
             .set_default("path", "")?
             .set_default("register_webhook_username", "")?
             .set_default("page_size", 1100)?

View File

@@ -69,7 +69,7 @@ impl Cmd {
             Self::Search(search) => search.run(db, &mut settings).await,

             #[cfg(feature = "sync")]
-            Self::Sync(sync) => sync.run(settings, &mut db).await,
+            Self::Sync(sync) => sync.run(settings, &mut db, &mut store).await,

             #[cfg(feature = "sync")]
             Self::Account(account) => account.run(settings).await,

View File

@@ -1,7 +1,7 @@
 use clap::Subcommand;
 use eyre::{Result, WrapErr};

-use atuin_client::{database::Database, settings::Settings};
+use atuin_client::{api_client, database::Database, record::store::Store, settings::Settings};

 mod status;
@@ -37,9 +37,14 @@
 }

 impl Cmd {
-    pub async fn run(self, settings: Settings, db: &mut impl Database) -> Result<()> {
+    pub async fn run(
+        self,
+        settings: Settings,
+        db: &mut impl Database,
+        store: &mut impl Store,
+    ) -> Result<()> {
         match self {
-            Self::Sync { force } => run(&settings, force, db).await,
+            Self::Sync { force } => run(&settings, force, db, store).await,

             Self::Login(l) => l.run(&settings).await,
             Self::Logout => account::logout::run(&settings),
             Self::Register(r) => r.run(&settings).await,
@@ -62,12 +67,17 @@
     }
 }

-async fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
-    atuin_client::sync::sync(settings, force, db).await?;
-    println!(
-        "Sync complete! {} items in database, force: {}",
-        db.history_count().await?,
-        force
-    );
+async fn run(
+    settings: &Settings,
+    force: bool,
+    db: &mut impl Database,
+    store: &mut impl Store,
+) -> Result<()> {
+    let host = Settings::host_id().expect("No host ID found");
+
+    // FOR TESTING ONLY!
+    let kv_tail = store.last(host, "kv").await?.expect("no kv found");
+    let client = api_client::Client::new(&settings.sync_address, &settings.session_token)?;
+
+    client.post_records(&[kv_tail]).await?;

     Ok(())
 }