Add new sync (#1093)

* Add record migration

* Add database functions for inserting history

No real tests yet :( I would like to avoid running postgres lol

* Add index handler, use UUIDs not strings

* Fix a bunch of tests, remove Option<Uuid>

* Add tests, all passing

* Working upload sync

* Record downloading works

* Sync download works

* Don't waste requests

* Use a page size for uploads, make it variable later

* Aaaaaand they're encrypted now too

* Add cek

* Allow reading tail across hosts

* Revert "Allow reading tail across hosts"

Not like that

This reverts commit 7b0c72e7e050c358172f9b53cbd21b9e44cf4931.

* Handle multiple shards properly

* format

* Format and make clippy happy

* use some fancy types (#1098)

* use some fancy types

* fmt

* Goodbye horrible tuple

* Update atuin-server-postgres/migrations/20230623070418_records.sql

Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>

* fmt

* Sort tests too because time sucks

* fix features

---------

Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
This commit is contained in:
Ellie Huxtable 2023-07-14 20:44:08 +01:00 committed by GitHub
parent 3d4302ded1
commit 97e24d0d41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1094 additions and 143 deletions

30
Cargo.lock generated
View File

@ -142,6 +142,7 @@ dependencies = [
"directories", "directories",
"eyre", "eyre",
"fs-err", "fs-err",
"futures",
"generic-array", "generic-array",
"hex", "hex",
"interim", "interim",
@ -151,6 +152,7 @@ dependencies = [
"memchr", "memchr",
"minspan", "minspan",
"parse_duration", "parse_duration",
"pretty_assertions",
"rand 0.8.5", "rand 0.8.5",
"regex", "regex",
"reqwest", "reqwest",
@ -182,6 +184,7 @@ dependencies = [
"pretty_assertions", "pretty_assertions",
"rand 0.8.5", "rand 0.8.5",
"serde", "serde",
"sqlx",
"typed-builder", "typed-builder",
"uuid", "uuid",
] ]
@ -240,6 +243,7 @@ dependencies = [
"serde", "serde",
"sqlx", "sqlx",
"tracing", "tracing",
"uuid",
] ]
[[package]] [[package]]
@ -950,6 +954,21 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541" checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541"
[[package]]
name = "futures"
version = "0.3.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.24" version = "0.3.24"
@ -988,6 +1007,12 @@ dependencies = [
"parking_lot 0.11.2", "parking_lot 0.11.2",
] ]
[[package]]
name = "futures-io"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
[[package]] [[package]]
name = "futures-macro" name = "futures-macro"
version = "0.3.24" version = "0.3.24"
@ -1017,10 +1042,13 @@ version = "0.3.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90"
dependencies = [ dependencies = [
"futures-channel",
"futures-core", "futures-core",
"futures-io",
"futures-macro", "futures-macro",
"futures-sink", "futures-sink",
"futures-task", "futures-task",
"memchr",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
"slab", "slab",
@ -2567,6 +2595,7 @@ dependencies = [
"thiserror", "thiserror",
"tokio-stream", "tokio-stream",
"url", "url",
"uuid",
"webpki-roots", "webpki-roots",
"whoami", "whoami",
] ]
@ -3037,6 +3066,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81"
dependencies = [ dependencies = [
"getrandom 0.2.7", "getrandom 0.2.7",
"serde",
] ]
[[package]] [[package]]

View File

@ -1,11 +1,11 @@
[workspace] [workspace]
members = [ members = [
"atuin", "atuin",
"atuin-client", "atuin-client",
"atuin-server", "atuin-server",
"atuin-server-postgres", "atuin-server-postgres",
"atuin-server-database", "atuin-server-database",
"atuin-common", "atuin-common",
] ]
[workspace.package] [workspace.package]
@ -35,9 +35,10 @@ semver = "1.0.14"
serde = { version = "1.0.145", features = ["derive"] } serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.86" serde_json = "1.0.86"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
uuid = { version = "1.3", features = ["v4"] } uuid = { version = "1.3", features = ["v4", "serde"] }
whoami = "1.1.2" whoami = "1.1.2"
typed-builder = "0.14.0" typed-builder = "0.14.0"
pretty_assertions = "1.3.0"
[workspace.dependencies.reqwest] [workspace.dependencies.reqwest]
version = "0.11" version = "0.11"
@ -46,4 +47,4 @@ default-features = false
[workspace.dependencies.sqlx] [workspace.dependencies.sqlx]
version = "0.6" version = "0.6"
features = ["runtime-tokio-rustls", "chrono", "postgres"] features = ["runtime-tokio-rustls", "chrono", "postgres", "uuid"]

View File

@ -54,10 +54,14 @@ rmp = { version = "0.8.11" }
typed-builder = "0.14.0" typed-builder = "0.14.0"
tokio = { workspace = true } tokio = { workspace = true }
semver = { workspace = true } semver = { workspace = true }
futures = "0.3"
# encryption # encryption
rusty_paseto = { version = "0.5.0", default-features = false } rusty_paseto = { version = "0.5.0", default-features = false }
rusty_paserk = { version = "0.2.0", default-features = false, features = ["v4", "serde"] } rusty_paserk = { version = "0.2.0", default-features = false, features = [
"v4",
"serde",
] }
# sync # sync
urlencoding = { version = "2.1.0", optional = true } urlencoding = { version = "2.1.0", optional = true }
@ -69,3 +73,4 @@ generic-array = { version = "0.14", optional = true, features = ["serde"] }
[dev-dependencies] [dev-dependencies]
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
pretty_assertions = { workspace = true }

View File

@ -7,7 +7,8 @@ create table if not exists records (
timestamp integer not null, timestamp integer not null,
tag text not null, tag text not null,
version text not null, version text not null,
data blob not null data blob not null,
cek blob not null
); );
create index host_idx on records (host); create index host_idx on records (host);

View File

@ -1,3 +0,0 @@
-- store content encryption keys in the record
alter table records
add column cek text;

View File

@ -8,9 +8,13 @@ use reqwest::{
StatusCode, Url, StatusCode, Url,
}; };
use atuin_common::api::{ use atuin_common::record::{EncryptedData, HostId, Record, RecordId};
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, use atuin_common::{
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse, api::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
},
record::RecordIndex,
}; };
use semver::Version; use semver::Version;
@ -195,6 +199,55 @@ impl<'a> Client<'a> {
Ok(()) Ok(())
} }
pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
let url = format!("{}/record", self.sync_addr);
let url = Url::parse(url.as_str())?;
self.client.post(url).json(records).send().await?;
Ok(())
}
pub async fn next_records(
&self,
host: HostId,
tag: String,
start: Option<RecordId>,
count: u64,
) -> Result<Vec<Record<EncryptedData>>> {
let url = format!(
"{}/record/next?host={}&tag={}&count={}",
self.sync_addr, host.0, tag, count
);
let mut url = Url::parse(url.as_str())?;
if let Some(start) = start {
url.set_query(Some(
format!(
"host={}&tag={}&count={}&start={}",
host.0, tag, count, start.0
)
.as_str(),
));
}
let resp = self.client.get(url).send().await?;
let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
Ok(records)
}
pub async fn record_index(&self) -> Result<RecordIndex> {
let url = format!("{}/record", self.sync_addr);
let url = Url::parse(url.as_str())?;
let resp = self.client.get(url).send().await?;
let index = resp.json().await?;
Ok(index)
}
pub async fn delete(&self) -> Result<()> { pub async fn delete(&self) -> Result<()> {
let url = format!("{}/account", self.sync_addr); let url = format!("{}/account", self.sync_addr);
let url = Url::parse(url.as_str())?; let url = Url::parse(url.as_str())?;

View File

@ -57,7 +57,7 @@ pub fn current_context() -> Context {
session, session,
hostname, hostname,
cwd, cwd,
host_id, host_id: host_id.0.as_simple().to_string(),
} }
} }

View File

@ -101,10 +101,7 @@ impl KvStore {
let bytes = record.serialize()?; let bytes = record.serialize()?;
let parent = store let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id);
.last(host_id.as_str(), KV_TAG)
.await?
.map(|entry| entry.id);
let record = atuin_common::record::Record::builder() let record = atuin_common::record::Record::builder()
.host(host_id) .host(host_id)
@ -130,17 +127,22 @@ impl KvStore {
namespace: &str, namespace: &str,
key: &str, key: &str,
) -> Result<Option<KvRecord>> { ) -> Result<Option<KvRecord>> {
// TODO: don't load this from disk so much
let host_id = Settings::host_id().expect("failed to get host_id");
// Currently, this is O(n). When we have an actual KV store, it can be better // Currently, this is O(n). When we have an actual KV store, it can be better
// Just a poc for now! // Just a poc for now!
// iterate records to find the value we want // iterate records to find the value we want
// start at the end, so we get the most recent version // start at the end, so we get the most recent version
let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else { let tails = store.tag_tails(KV_TAG).await?;
if tails.is_empty() {
return Ok(None); return Ok(None);
}; }
// first, decide on a record.
// try getting the newest first
// we always need a way of deciding the "winner" of a write
// TODO(ellie): something better than last-write-wins, what if two write at the same time?
let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
loop { loop {
let decrypted = match record.version.as_str() { let decrypted = match record.version.as_str() {
@ -154,7 +156,7 @@ impl KvStore {
} }
if let Some(parent) = decrypted.parent { if let Some(parent) = decrypted.parent {
record = store.get(parent.as_str()).await?; record = store.get(parent).await?;
} else { } else {
break; break;
} }

View File

@ -1,4 +1,6 @@
use atuin_common::record::{AdditionalData, DecryptedData, EncryptedData, Encryption}; use atuin_common::record::{
AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
};
use base64::{engine::general_purpose, Engine}; use base64::{engine::general_purpose, Engine};
use eyre::{ensure, Context, Result}; use eyre::{ensure, Context, Result};
use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; use rusty_paserk::{Key, KeyId, Local, PieWrappedKey};
@ -158,10 +160,11 @@ struct AtuinFooter {
// This cannot be changed, otherwise it breaks the authenticated encryption. // This cannot be changed, otherwise it breaks the authenticated encryption.
#[derive(Debug, Copy, Clone, Serialize)] #[derive(Debug, Copy, Clone, Serialize)]
struct Assertions<'a> { struct Assertions<'a> {
id: &'a str, id: &'a RecordId,
version: &'a str, version: &'a str,
tag: &'a str, tag: &'a str,
host: &'a str, host: &'a HostId,
parent: Option<&'a RecordId>,
} }
impl<'a> From<AdditionalData<'a>> for Assertions<'a> { impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
@ -171,6 +174,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
version: ad.version, version: ad.version,
tag: ad.tag, tag: ad.tag,
host: ad.host, host: ad.host,
parent: ad.parent,
} }
} }
} }
@ -183,7 +187,7 @@ impl Assertions<'_> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use atuin_common::record::Record; use atuin_common::{record::Record, utils::uuid_v7};
use super::*; use super::*;
@ -192,10 +196,11 @@ mod tests {
let key = Key::<V4, Local>::new_os_random(); let key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData { let ad = AdditionalData {
id: "foo", id: &RecordId(uuid_v7()),
version: "v0", version: "v0",
tag: "kv", tag: "kv",
host: "1234", host: &HostId(uuid_v7()),
parent: None,
}; };
let data = DecryptedData(vec![1, 2, 3, 4]); let data = DecryptedData(vec![1, 2, 3, 4]);
@ -210,10 +215,11 @@ mod tests {
let key = Key::<V4, Local>::new_os_random(); let key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData { let ad = AdditionalData {
id: "foo", id: &RecordId(uuid_v7()),
version: "v0", version: "v0",
tag: "kv", tag: "kv",
host: "1234", host: &HostId(uuid_v7()),
parent: None,
}; };
let data = DecryptedData(vec![1, 2, 3, 4]); let data = DecryptedData(vec![1, 2, 3, 4]);
@ -233,10 +239,11 @@ mod tests {
let fake_key = Key::<V4, Local>::new_os_random(); let fake_key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData { let ad = AdditionalData {
id: "foo", id: &RecordId(uuid_v7()),
version: "v0", version: "v0",
tag: "kv", tag: "kv",
host: "1234", host: &HostId(uuid_v7()),
parent: None,
}; };
let data = DecryptedData(vec![1, 2, 3, 4]); let data = DecryptedData(vec![1, 2, 3, 4]);
@ -250,10 +257,11 @@ mod tests {
let key = Key::<V4, Local>::new_os_random(); let key = Key::<V4, Local>::new_os_random();
let ad = AdditionalData { let ad = AdditionalData {
id: "foo", id: &RecordId(uuid_v7()),
version: "v0", version: "v0",
tag: "kv", tag: "kv",
host: "1234", host: &HostId(uuid_v7()),
parent: None,
}; };
let data = DecryptedData(vec![1, 2, 3, 4]); let data = DecryptedData(vec![1, 2, 3, 4]);
@ -261,10 +269,8 @@ mod tests {
let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes());
let ad = AdditionalData { let ad = AdditionalData {
id: "foo1", id: &RecordId(uuid_v7()),
version: "v0", ..ad
tag: "kv",
host: "1234",
}; };
let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err();
} }
@ -275,10 +281,11 @@ mod tests {
let key2 = Key::<V4, Local>::new_os_random(); let key2 = Key::<V4, Local>::new_os_random();
let ad = AdditionalData { let ad = AdditionalData {
id: "foo", id: &RecordId(uuid_v7()),
version: "v0", version: "v0",
tag: "kv", tag: "kv",
host: "1234", host: &HostId(uuid_v7()),
parent: None,
}; };
let data = DecryptedData(vec![1, 2, 3, 4]); let data = DecryptedData(vec![1, 2, 3, 4]);
@ -304,10 +311,10 @@ mod tests {
fn full_record_round_trip() { fn full_record_round_trip() {
let key = [0x55; 32]; let key = [0x55; 32];
let record = Record::builder() let record = Record::builder()
.id("1".to_owned()) .id(RecordId(uuid_v7()))
.version("v0".to_owned()) .version("v0".to_owned())
.tag("kv".to_owned()) .tag("kv".to_owned())
.host("host1".to_owned()) .host(HostId(uuid_v7()))
.timestamp(1687244806000000) .timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4])) .data(DecryptedData(vec![1, 2, 3, 4]))
.build(); .build();
@ -316,30 +323,20 @@ mod tests {
assert!(!encrypted.data.data.is_empty()); assert!(!encrypted.data.data.is_empty());
assert!(!encrypted.data.content_encryption_key.is_empty()); assert!(!encrypted.data.content_encryption_key.is_empty());
assert_eq!(encrypted.id, "1");
assert_eq!(encrypted.host, "host1");
assert_eq!(encrypted.version, "v0");
assert_eq!(encrypted.tag, "kv");
assert_eq!(encrypted.timestamp, 1687244806000000);
let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap(); let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap();
assert_eq!(decrypted.data.0, [1, 2, 3, 4]); assert_eq!(decrypted.data.0, [1, 2, 3, 4]);
assert_eq!(decrypted.id, "1");
assert_eq!(decrypted.host, "host1");
assert_eq!(decrypted.version, "v0");
assert_eq!(decrypted.tag, "kv");
assert_eq!(decrypted.timestamp, 1687244806000000);
} }
#[test] #[test]
fn full_record_round_trip_fail() { fn full_record_round_trip_fail() {
let key = [0x55; 32]; let key = [0x55; 32];
let record = Record::builder() let record = Record::builder()
.id("1".to_owned()) .id(RecordId(uuid_v7()))
.version("v0".to_owned()) .version("v0".to_owned())
.tag("kv".to_owned()) .tag("kv".to_owned())
.host("host1".to_owned()) .host(HostId(uuid_v7()))
.timestamp(1687244806000000) .timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4])) .data(DecryptedData(vec![1, 2, 3, 4]))
.build(); .build();
@ -347,13 +344,13 @@ mod tests {
let encrypted = record.encrypt::<PASETO_V4>(&key); let encrypted = record.encrypt::<PASETO_V4>(&key);
let mut enc1 = encrypted.clone(); let mut enc1 = encrypted.clone();
enc1.host = "host2".to_owned(); enc1.host = HostId(uuid_v7());
let _ = enc1 let _ = enc1
.decrypt::<PASETO_V4>(&key) .decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the host should result in auth failure"); .expect_err("tampering with the host should result in auth failure");
let mut enc2 = encrypted; let mut enc2 = encrypted;
enc2.id = "2".to_owned(); enc2.id = RecordId(uuid_v7());
let _ = enc2 let _ = enc2
.decrypt::<PASETO_V4>(&key) .decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the id should result in auth failure"); .expect_err("tampering with the id should result in auth failure");

View File

@ -1,3 +1,5 @@
pub mod encryption; pub mod encryption;
pub mod sqlite_store; pub mod sqlite_store;
pub mod store; pub mod store;
#[cfg(feature = "sync")]
pub mod sync;

View File

@ -8,12 +8,14 @@ use std::str::FromStr;
use async_trait::async_trait; use async_trait::async_trait;
use eyre::{eyre, Result}; use eyre::{eyre, Result};
use fs_err as fs; use fs_err as fs;
use futures::TryStreamExt;
use sqlx::{ use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Row, Row,
}; };
use atuin_common::record::{EncryptedData, Record}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use uuid::Uuid;
use super::store::Store; use super::store::Store;
@ -62,11 +64,11 @@ impl SqliteStore {
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek) "insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
) )
.bind(r.id.as_str()) .bind(r.id.0.as_simple().to_string())
.bind(r.host.as_str()) .bind(r.host.0.as_simple().to_string())
.bind(r.tag.as_str()) .bind(r.tag.as_str())
.bind(r.timestamp as i64) .bind(r.timestamp as i64)
.bind(r.parent.as_ref()) .bind(r.parent.map(|p| p.0.as_simple().to_string()))
.bind(r.version.as_str()) .bind(r.version.as_str())
.bind(r.data.data.as_str()) .bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str()) .bind(r.data.content_encryption_key.as_str())
@ -79,10 +81,18 @@ impl SqliteStore {
fn query_row(row: SqliteRow) -> Record<EncryptedData> { fn query_row(row: SqliteRow) -> Record<EncryptedData> {
let timestamp: i64 = row.get("timestamp"); let timestamp: i64 = row.get("timestamp");
// tbh at this point things are pretty fucked so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
let parent: Option<&str> = row.get("parent");
let parent = parent
.map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
Record { Record {
id: row.get("id"), id: RecordId(id),
host: row.get("host"), host: HostId(host),
parent: row.get("parent"), parent: parent.map(RecordId),
timestamp: timestamp as u64, timestamp: timestamp as u64,
tag: row.get("tag"), tag: row.get("tag"),
version: row.get("version"), version: row.get("version"),
@ -111,9 +121,9 @@ impl Store for SqliteStore {
Ok(()) Ok(())
} }
async fn get(&self, id: &str) -> Result<Record<EncryptedData>> { async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
let res = sqlx::query("select * from records where id = ?1") let res = sqlx::query("select * from records where id = ?1")
.bind(id) .bind(id.0.as_simple().to_string())
.map(Self::query_row) .map(Self::query_row)
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await?; .await?;
@ -121,10 +131,10 @@ impl Store for SqliteStore {
Ok(res) Ok(res)
} }
async fn len(&self, host: &str, tag: &str) -> Result<u64> { async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
let res: (i64,) = let res: (i64,) =
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
.bind(host) .bind(host.0.as_simple().to_string())
.bind(tag) .bind(tag)
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await?; .await?;
@ -134,7 +144,7 @@ impl Store for SqliteStore {
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> { async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query("select * from records where parent = ?1") let res = sqlx::query("select * from records where parent = ?1")
.bind(record.id.clone()) .bind(record.id.0.as_simple().to_string())
.map(Self::query_row) .map(Self::query_row)
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await; .await;
@ -146,11 +156,11 @@ impl Store for SqliteStore {
} }
} }
async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> { async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query( let res = sqlx::query(
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1", "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
) )
.bind(host) .bind(host.0.as_simple().to_string())
.bind(tag) .bind(tag)
.map(Self::query_row) .map(Self::query_row)
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
@ -159,23 +169,53 @@ impl Store for SqliteStore {
Ok(res) Ok(res)
} }
async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> { async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query( let res = sqlx::query(
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;", "select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
) )
.bind(tag) .bind(tag)
.bind(host) .bind(host.0.as_simple().to_string())
.map(Self::query_row) .map(Self::query_row)
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await?; .await?;
Ok(res) Ok(res)
} }
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
)
.bind(tag)
.map(Self::query_row)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
async fn tail_records(&self) -> Result<RecordIndex> {
let res = sqlx::query(
"select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
)
.map(|row: SqliteRow| {
let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
let tag: String= row.get("tag");
let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
(HostId(host), tag, RecordId(id))
})
.fetch(&self.pool)
.try_collect()
.await?;
Ok(res)
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use atuin_common::record::{EncryptedData, Record}; use atuin_common::record::{EncryptedData, HostId, Record};
use crate::record::{encryption::PASETO_V4, store::Store}; use crate::record::{encryption::PASETO_V4, store::Store};
@ -183,7 +223,7 @@ mod tests {
fn test_record() -> Record<EncryptedData> { fn test_record() -> Record<EncryptedData> {
Record::builder() Record::builder()
.host(atuin_common::utils::uuid_v7().simple().to_string()) .host(HostId(atuin_common::utils::uuid_v7()))
.version("v1".into()) .version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string()) .tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData { .data(EncryptedData {
@ -218,10 +258,7 @@ mod tests {
let record = test_record(); let record = test_record();
db.push(&record).await.unwrap(); db.push(&record).await.unwrap();
let new_record = db let new_record = db.get(record.id).await.expect("failed to fetch record");
.get(record.id.as_str())
.await
.expect("failed to fetch record");
assert_eq!(record, new_record, "records are not equal"); assert_eq!(record, new_record, "records are not equal");
} }
@ -233,7 +270,7 @@ mod tests {
db.push(&record).await.unwrap(); db.push(&record).await.unwrap();
let len = db let len = db
.len(record.host.as_str(), record.tag.as_str()) .len(record.host, record.tag.as_str())
.await .await
.expect("failed to get store len"); .expect("failed to get store len");
@ -253,14 +290,8 @@ mod tests {
db.push(&first).await.unwrap(); db.push(&first).await.unwrap();
db.push(&second).await.unwrap(); db.push(&second).await.unwrap();
let first_len = db let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
.len(first.host.as_str(), first.tag.as_str()) let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
.await
.unwrap();
let second_len = db
.len(second.host.as_str(), second.tag.as_str())
.await
.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert"); assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert"); assert_eq!(second_len, 1, "expected length of 1 after insert");
@ -281,7 +312,7 @@ mod tests {
} }
assert_eq!( assert_eq!(
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), db.len(tail.host, tail.tag.as_str()).await.unwrap(),
100, 100,
"failed to insert 100 records" "failed to insert 100 records"
); );
@ -304,7 +335,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap(); db.push_batch(records.iter()).await.unwrap();
assert_eq!( assert_eq!(
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(), db.len(tail.host, tail.tag.as_str()).await.unwrap(),
10000, 10000,
"failed to insert 10k records" "failed to insert 10k records"
); );
@ -327,7 +358,7 @@ mod tests {
db.push_batch(records.iter()).await.unwrap(); db.push_batch(records.iter()).await.unwrap();
let mut record = db let mut record = db
.first(tail.host.as_str(), tail.tag.as_str()) .head(tail.host, tail.tag.as_str())
.await .await
.expect("in memory sqlite should not fail") .expect("in memory sqlite should not fail")
.expect("entry exists"); .expect("entry exists");

View File

@ -1,7 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use eyre::Result; use eyre::Result;
use atuin_common::record::{EncryptedData, Record}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
/// A record store stores records /// A record store stores records
/// In more detail - we tend to need to process this into _another_ format to actually query it. /// In more detail - we tend to need to process this into _another_ format to actually query it.
@ -20,14 +20,22 @@ pub trait Store {
records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
) -> Result<()>; ) -> Result<()>;
async fn get(&self, id: &str) -> Result<Record<EncryptedData>>; async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
async fn len(&self, host: &str, tag: &str) -> Result<u64>; async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
/// Get the record that follows this record /// Get the record that follows this record
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>; async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
/// Get the first record for a given host and tag /// Get the first record for a given host and tag
async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>; async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
/// Get the last record for a given host and tag /// Get the last record for a given host and tag
async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>; async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
// Get the last record for all hosts for a given tag, useful for the read path of apps.
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
// Get the latest host/tag/record tuple for every set in the store. useful for building an
// index
async fn tail_records(&self) -> Result<RecordIndex>;
} }

View File

@ -0,0 +1,421 @@
// do a sync :O
use eyre::Result;
use super::store::Store;
use crate::{api_client::Client, settings::Settings};
use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
#[derive(Debug, Eq, PartialEq)]
pub enum Operation {
// Either upload or download until the tail matches the below
Upload {
tail: RecordId,
host: HostId,
tag: String,
},
Download {
tail: RecordId,
host: HostId,
tag: String,
},
}
pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
let client = Client::new(&settings.sync_address, &settings.session_token)?;
let local_index = store.tail_records().await?;
let remote_index = client.record_index().await?;
let diff = local_index.diff(&remote_index);
Ok((diff, remote_index))
}
// Take a diff, along with a local store, and resolve it into a set of operations.
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
// In theory this could be done as a part of the diffing stage, but it's easier to reason
// about and test this way
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
let mut operations = Vec::with_capacity(diffs.len());
for diff in diffs {
// First, try to fetch the tail
// If it exists locally, then that means we need to update the remote
// host until it has the same tail. Ie, upload.
// If it does not exist locally, that means remote is ahead of us.
// Therefore, we need to download until our local tail matches
let record = store.get(diff.tail).await;
let op = if record.is_ok() {
// if local has the ID, then we should find the actual tail of this
// store, so we know what we need to update the remote to.
let tail = store
.tail(diff.host, diff.tag.as_str())
.await?
.expect("failed to fetch last record, expected tag/host to exist");
// TODO(ellie) update the diffing so that it stores the context of the current tail
// that way, we can determine how much we need to upload.
// For now just keep uploading until tails match
Operation::Upload {
tail: tail.id,
host: diff.host,
tag: diff.tag,
}
} else {
Operation::Download {
tail: diff.tail,
host: diff.host,
tag: diff.tag,
}
};
operations.push(op);
}
// sort them - purely so we have a stable testing order, and can rely on
// same input = same output
// We can sort by ID so long as we continue to use UUIDv7 or something
// with the same properties
operations.sort_by_key(|op| match op {
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
Operation::Download { tail, host, .. } => ("download", *host, *tail),
});
Ok(operations)
}
async fn sync_upload(
store: &mut impl Store,
remote_index: &RecordIndex,
client: &Client<'_>,
op: (HostId, String, RecordId),
) -> Result<i64> {
let upload_page_size = 100;
let mut total = 0;
// so. we have an upload operation, with the tail representing the state
// we want to get the remote to
let current_tail = remote_index.get(op.0, op.1.clone());
println!(
"Syncing local {:?}/{}/{:?}, remote has {:?}",
op.0, op.1, op.2, current_tail
);
let start = if let Some(current_tail) = current_tail {
current_tail
} else {
store
.head(op.0, op.1.as_str())
.await
.expect("failed to fetch host/tag head")
.expect("host/tag not in current index")
.id
};
debug!("starting push to remote from: {:?}", start);
// we have the start point for sync. it is either the head of the store if
// the remote has no data for it, or the tail that the remote has
// we need to iterate from the remote tail, and keep going until
// remote tail = current local tail
let mut record = Some(store.get(start).await.unwrap());
let mut buf = Vec::with_capacity(upload_page_size);
while let Some(r) = record {
if buf.len() < upload_page_size {
buf.push(r.clone());
} else {
client.post_records(&buf).await?;
// can we reset what we have? len = 0 but keep capacity
buf = Vec::with_capacity(upload_page_size);
}
record = store.next(&r).await?;
total += 1;
}
if !buf.is_empty() {
client.post_records(&buf).await?;
}
Ok(total)
}
async fn sync_download(
store: &mut impl Store,
remote_index: &RecordIndex,
client: &Client<'_>,
op: (HostId, String, RecordId),
) -> Result<i64> {
// TODO(ellie): implement variable page sizing like on history sync
let download_page_size = 1000;
let mut total = 0;
// We know that the remote is ahead of us, so let's keep downloading until both
// 1) The remote stops returning full pages
// 2) The tail equals what we expect
//
// If (1) occurs without (2), then something is wrong with our index calculation
// and we should bail.
let remote_tail = remote_index
.get(op.0, op.1.clone())
.expect("remote index does not contain expected tail during download");
let local_tail = store.tail(op.0, op.1.as_str()).await?;
//
// We expect that the operations diff will represent the desired state
// In this case, that contains the remote tail.
assert_eq!(remote_tail, op.2);
println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
let mut records = client
.next_records(
op.0,
op.1.clone(),
local_tail.map(|r| r.id),
download_page_size,
)
.await?;
while !records.is_empty() {
total += std::cmp::min(download_page_size, records.len() as u64);
store.push_batch(records.iter()).await?;
if records.last().unwrap().id == remote_tail {
break;
}
records = client
.next_records(
op.0,
op.1.clone(),
records.last().map(|r| r.id),
download_page_size,
)
.await?;
}
Ok(total as i64)
}
pub async fn sync_remote(
operations: Vec<Operation>,
remote_index: &RecordIndex,
local_store: &mut impl Store,
settings: &Settings,
) -> Result<(i64, i64)> {
let client = Client::new(&settings.sync_address, &settings.session_token)?;
let mut uploaded = 0;
let mut downloaded = 0;
// this can totally run in parallel, but lets get it working first
for i in operations {
match i {
Operation::Upload { tail, host, tag } => {
uploaded +=
sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
}
Operation::Download { tail, host, tag } => {
downloaded +=
sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
}
}
}
Ok((uploaded, downloaded))
}
#[cfg(test)]
mod tests {
use atuin_common::record::{Diff, EncryptedData, HostId, Record};
use pretty_assertions::assert_eq;
use crate::record::{
encryption::PASETO_V4,
sqlite_store::SqliteStore,
store::Store,
sync::{self, Operation},
};
fn test_record() -> Record<EncryptedData> {
Record::builder()
.host(HostId(atuin_common::utils::uuid_v7()))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: String::new(),
content_encryption_key: String::new(),
})
.build()
}
// Take a list of local records, and a list of remote records.
// Return the local database, and a diff of local/remote, ready to build
// ops
async fn build_test_diff(
local_records: Vec<Record<EncryptedData>>,
remote_records: Vec<Record<EncryptedData>>,
) -> (SqliteStore, Vec<Diff>) {
let local_store = SqliteStore::new(":memory:")
.await
.expect("failed to open in memory sqlite");
let remote_store = SqliteStore::new(":memory:")
.await
.expect("failed to open in memory sqlite"); // "remote"
for i in local_records {
local_store.push(&i).await.unwrap();
}
for i in remote_records {
remote_store.push(&i).await.unwrap();
}
let local_index = local_store.tail_records().await.unwrap();
let remote_index = remote_store.tail_records().await.unwrap();
let diff = local_index.diff(&remote_index);
(local_store, diff)
}
#[tokio::test]
async fn test_basic_diff() {
// a diff where local is ahead of remote. nothing else.
let record = test_record();
let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await;
assert_eq!(diff.len(), 1);
let operations = sync::operations(diff, &store).await.unwrap();
assert_eq!(operations.len(), 1);
assert_eq!(
operations[0],
Operation::Upload {
host: record.host,
tag: record.tag,
tail: record.id
}
);
}
#[tokio::test]
async fn build_two_way_diff() {
// a diff where local is ahead of remote for one, and remote for
// another. One upload, one download
let shared_record = test_record();
let remote_ahead = test_record();
let local_ahead = shared_record
.new_child(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
let (store, diff) = build_test_diff(local, remote).await;
let operations = sync::operations(diff, &store).await.unwrap();
assert_eq!(operations.len(), 2);
assert_eq!(
operations,
vec![
Operation::Download {
tail: remote_ahead.id,
host: remote_ahead.host,
tag: remote_ahead.tag,
},
Operation::Upload {
tail: local_ahead.id,
host: local_ahead.host,
tag: local_ahead.tag,
},
]
);
}
#[tokio::test]
async fn build_complex_diff() {
// One shared, ahead but known only by remote
// One known only by local
// One known only by remote
let shared_record = test_record();
let remote_known = test_record();
let local_known = test_record();
let second_shared = test_record();
let second_shared_remote_ahead = second_shared
.new_child(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local_ahead = shared_record
.new_child(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![
shared_record.clone(),
second_shared.clone(),
local_known.clone(),
local_ahead.clone(),
];
let remote = vec![
shared_record.clone(),
second_shared.clone(),
second_shared_remote_ahead.clone(),
remote_known.clone(),
]; // remote knows about the already-synced, and one new record in a new store
let (store, diff) = build_test_diff(local, remote).await;
let operations = sync::operations(diff, &store).await.unwrap();
assert_eq!(operations.len(), 4);
let mut result_ops = vec![
Operation::Download {
tail: remote_known.id,
host: remote_known.host,
tag: remote_known.tag,
},
Operation::Download {
tail: second_shared_remote_ahead.id,
host: second_shared.host,
tag: second_shared.tag,
},
Operation::Upload {
tail: local_ahead.id,
host: local_ahead.host,
tag: local_ahead.tag,
},
Operation::Upload {
tail: local_known.id,
host: local_known.host,
tag: local_known.tag,
},
];
result_ops.sort_by_key(|op| match op {
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
Operation::Download { tail, host, .. } => ("download", *host, *tail),
});
assert_eq!(operations, result_ops);
}
}

View File

@ -1,8 +1,10 @@
use std::{ use std::{
io::prelude::*, io::prelude::*,
path::{Path, PathBuf}, path::{Path, PathBuf},
str::FromStr,
}; };
use atuin_common::record::HostId;
use chrono::{prelude::*, Utc}; use chrono::{prelude::*, Utc};
use clap::ValueEnum; use clap::ValueEnum;
use config::{Config, Environment, File as ConfigFile, FileFormat}; use config::{Config, Environment, File as ConfigFile, FileFormat};
@ -12,6 +14,7 @@ use parse_duration::parse;
use regex::RegexSet; use regex::RegexSet;
use semver::Version; use semver::Version;
use serde::Deserialize; use serde::Deserialize;
use uuid::Uuid;
pub const HISTORY_PAGE_SIZE: i64 = 100; pub const HISTORY_PAGE_SIZE: i64 = 100;
pub const LAST_SYNC_FILENAME: &str = "last_sync_time"; pub const LAST_SYNC_FILENAME: &str = "last_sync_time";
@ -228,11 +231,13 @@ impl Settings {
Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME) Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME)
} }
pub fn host_id() -> Option<String> { pub fn host_id() -> Option<HostId> {
let id = Settings::read_from_data_dir(HOST_ID_FILENAME); let id = Settings::read_from_data_dir(HOST_ID_FILENAME);
if id.is_some() { if let Some(id) = id {
return id; let parsed =
Uuid::from_str(id.as_str()).expect("failed to parse host ID from local directory");
return Some(HostId(parsed));
} }
let uuid = atuin_common::utils::uuid_v7(); let uuid = atuin_common::utils::uuid_v7();
@ -240,7 +245,7 @@ impl Settings {
Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref()) Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref())
.expect("Could not write host ID to data dir"); .expect("Could not write host ID to data dir");
Some(uuid.as_simple().to_string()) Some(HostId(uuid))
} }
pub fn should_sync(&self) -> Result<bool> { pub fn should_sync(&self) -> Result<bool> {

View File

@ -18,6 +18,7 @@ uuid = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
typed-builder = { workspace = true } typed-builder = { workspace = true }
eyre = { workspace = true } eyre = { workspace = true }
sqlx = { workspace = true }
[dev-dependencies] [dev-dependencies]
pretty_assertions = "1.3.0" pretty_assertions = { workspace = true }

View File

@ -1,5 +1,57 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
/// Defines a new UUID type wrapper
macro_rules! new_uuid {
($name:ident) => {
#[derive(
Debug,
Copy,
Clone,
PartialEq,
Eq,
Hash,
PartialOrd,
Ord,
serde::Serialize,
serde::Deserialize,
)]
#[serde(transparent)]
pub struct $name(pub Uuid);
impl<DB: sqlx::Database> sqlx::Type<DB> for $name
where
Uuid: sqlx::Type<DB>,
{
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
Uuid::type_info()
}
}
impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name
where
Uuid: sqlx::Decode<'r, DB>,
{
fn decode(
value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
) -> std::result::Result<Self, sqlx::error::BoxDynError> {
Uuid::decode(value).map(Self)
}
}
impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name
where
Uuid: sqlx::Encode<'q, DB>,
{
fn encode_by_ref(
&self,
buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer,
) -> sqlx::encode::IsNull {
self.0.encode_by_ref(buf)
}
}
};
}
pub mod api; pub mod api;
pub mod record; pub mod record;
pub mod utils; pub mod utils;

View File

@ -3,35 +3,43 @@ use std::collections::HashMap;
use eyre::Result; use eyre::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use typed_builder::TypedBuilder; use typed_builder::TypedBuilder;
use uuid::Uuid;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct DecryptedData(pub Vec<u8>); pub struct DecryptedData(pub Vec<u8>);
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EncryptedData { pub struct EncryptedData {
pub data: String, pub data: String,
pub content_encryption_key: String, pub content_encryption_key: String,
} }
#[derive(Debug, PartialEq)]
pub struct Diff {
pub host: HostId,
pub tag: String,
pub tail: RecordId,
}
/// A single record stored inside of our local database /// A single record stored inside of our local database
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
pub struct Record<Data> { pub struct Record<Data> {
/// a unique ID /// a unique ID
#[builder(default = crate::utils::uuid_v7().as_simple().to_string())] #[builder(default = RecordId(crate::utils::uuid_v7()))]
pub id: String, pub id: RecordId,
/// The unique ID of the host. /// The unique ID of the host.
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
// as strings. I would rather avoid normalization, so store as UUID binary instead of // as strings. I would rather avoid normalization, so store as UUID binary instead of
// encoding to a string and wasting much more storage. // encoding to a string and wasting much more storage.
pub host: String, pub host: HostId,
/// The ID of the parent entry /// The ID of the parent entry
// A store is technically just a double linked list // A store is technically just a double linked list
// We can do some cheating with the timestamps, but should not rely upon them. // We can do some cheating with the timestamps, but should not rely upon them.
// Clocks are tricksy. // Clocks are tricksy.
#[builder(default)] #[builder(default)]
pub parent: Option<String>, pub parent: Option<RecordId>,
/// The creation time in nanoseconds since unix epoch /// The creation time in nanoseconds since unix epoch
#[builder(default = chrono::Utc::now().timestamp_nanos() as u64)] #[builder(default = chrono::Utc::now().timestamp_nanos() as u64)]
@ -48,21 +56,25 @@ pub struct Record<Data> {
pub data: Data, pub data: Data,
} }
new_uuid!(RecordId);
new_uuid!(HostId);
/// Extra data from the record that should be encoded in the data /// Extra data from the record that should be encoded in the data
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
pub struct AdditionalData<'a> { pub struct AdditionalData<'a> {
pub id: &'a str, pub id: &'a RecordId,
pub version: &'a str, pub version: &'a str,
pub tag: &'a str, pub tag: &'a str,
pub host: &'a str, pub host: &'a HostId,
pub parent: Option<&'a RecordId>,
} }
impl<Data> Record<Data> { impl<Data> Record<Data> {
pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> { pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
Record::builder() Record::builder()
.host(self.host.clone()) .host(self.host)
.version(self.version.clone()) .version(self.version.clone())
.parent(Some(self.id.clone())) .parent(Some(self.id))
.tag(self.tag.clone()) .tag(self.tag.clone())
.data(DecryptedData(data)) .data(DecryptedData(data))
.build() .build()
@ -71,9 +83,10 @@ impl<Data> Record<Data> {
/// An index representing the current state of the record stores /// An index representing the current state of the record stores
/// This can be both remote, or local, and compared in either direction /// This can be both remote, or local, and compared in either direction
#[derive(Debug, Serialize, Deserialize)]
pub struct RecordIndex { pub struct RecordIndex {
// A map of host -> tag -> tail // A map of host -> tag -> tail
pub hosts: HashMap<String, HashMap<String, String>>, pub hosts: HashMap<HostId, HashMap<String, RecordId>>,
} }
impl Default for RecordIndex { impl Default for RecordIndex {
@ -82,6 +95,14 @@ impl Default for RecordIndex {
} }
} }
impl Extend<(HostId, String, RecordId)> for RecordIndex {
fn extend<T: IntoIterator<Item = (HostId, String, RecordId)>>(&mut self, iter: T) {
for (host, tag, tail_id) in iter {
self.set_raw(host, tag, tail_id);
}
}
}
impl RecordIndex { impl RecordIndex {
pub fn new() -> RecordIndex { pub fn new() -> RecordIndex {
RecordIndex { RecordIndex {
@ -91,13 +112,14 @@ impl RecordIndex {
/// Insert a new tail record into the store /// Insert a new tail record into the store
pub fn set(&mut self, tail: Record<DecryptedData>) { pub fn set(&mut self, tail: Record<DecryptedData>) {
self.hosts self.set_raw(tail.host, tail.tag, tail.id)
.entry(tail.host)
.or_default()
.insert(tail.tag, tail.id);
} }
pub fn get(&self, host: String, tag: String) -> Option<String> { pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordId) {
self.hosts.entry(host).or_default().insert(tag, tail_id);
}
pub fn get(&self, host: HostId, tag: String) -> Option<RecordId> {
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned() self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
} }
@ -108,21 +130,29 @@ impl RecordIndex {
/// other machine has a different tail, it will be the differing tail. This is useful to /// other machine has a different tail, it will be the differing tail. This is useful to
/// check if the other index is ahead of us, or behind. /// check if the other index is ahead of us, or behind.
/// If the other index does not have the (host, tag) pair, then the other value will be None. /// If the other index does not have the (host, tag) pair, then the other value will be None.
pub fn diff(&self, other: &Self) -> Vec<(String, String, Option<String>)> { pub fn diff(&self, other: &Self) -> Vec<Diff> {
let mut ret = Vec::new(); let mut ret = Vec::new();
// First, we check if other has everything that self has // First, we check if other has everything that self has
for (host, tag_map) in self.hosts.iter() { for (host, tag_map) in self.hosts.iter() {
for (tag, tail) in tag_map.iter() { for (tag, tail) in tag_map.iter() {
match other.get(host.clone(), tag.clone()) { match other.get(*host, tag.clone()) {
// The other store is all up to date! No diff. // The other store is all up to date! No diff.
Some(t) if t.eq(tail) => continue, Some(t) if t.eq(tail) => continue,
// The other store does exist, but it is either ahead or behind us. A diff regardless // The other store does exist, but it is either ahead or behind us. A diff regardless
Some(t) => ret.push((host.clone(), tag.clone(), Some(t))), Some(t) => ret.push(Diff {
host: *host,
tag: tag.clone(),
tail: t,
}),
// The other store does not exist :O // The other store does not exist :O
None => ret.push((host.clone(), tag.clone(), None)), None => ret.push(Diff {
host: *host,
tag: tag.clone(),
tail: *tail,
}),
}; };
} }
} }
@ -133,16 +163,20 @@ impl RecordIndex {
// account for that! // account for that!
for (host, tag_map) in other.hosts.iter() { for (host, tag_map) in other.hosts.iter() {
for (tag, tail) in tag_map.iter() { for (tag, tail) in tag_map.iter() {
match self.get(host.clone(), tag.clone()) { match self.get(*host, tag.clone()) {
// If we have this host/tag combo, the comparison and diff will have already happened above // If we have this host/tag combo, the comparison and diff will have already happened above
Some(_) => continue, Some(_) => continue,
None => ret.push((host.clone(), tag.clone(), Some(tail.clone()))), None => ret.push(Diff {
host: *host,
tag: tag.clone(),
tail: *tail,
}),
}; };
} }
} }
ret.sort(); ret.sort_by(|a, b| (a.host, a.tag.clone(), a.tail).cmp(&(b.host, b.tag.clone(), b.tail)));
ret ret
} }
} }
@ -168,6 +202,7 @@ impl Record<DecryptedData> {
version: &self.version, version: &self.version,
tag: &self.tag, tag: &self.tag,
host: &self.host, host: &self.host,
parent: self.parent.as_ref(),
}; };
Record { Record {
data: E::encrypt(self.data, ad, key), data: E::encrypt(self.data, ad, key),
@ -188,6 +223,7 @@ impl Record<EncryptedData> {
version: &self.version, version: &self.version,
tag: &self.tag, tag: &self.tag,
host: &self.host, host: &self.host,
parent: self.parent.as_ref(),
}; };
Ok(Record { Ok(Record {
data: E::decrypt(self.data, ad, key)?, data: E::decrypt(self.data, ad, key)?,
@ -210,6 +246,7 @@ impl Record<EncryptedData> {
version: &self.version, version: &self.version,
tag: &self.tag, tag: &self.tag,
host: &self.host, host: &self.host,
parent: self.parent.as_ref(),
}; };
Ok(Record { Ok(Record {
data: E::re_encrypt(self.data, ad, old_key, new_key)?, data: E::re_encrypt(self.data, ad, old_key, new_key)?,
@ -225,12 +262,14 @@ impl Record<EncryptedData> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{DecryptedData, Record, RecordIndex}; use crate::record::HostId;
use super::{DecryptedData, Diff, Record, RecordIndex};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
fn test_record() -> Record<DecryptedData> { fn test_record() -> Record<DecryptedData> {
Record::builder() Record::builder()
.host(crate::utils::uuid_v7().simple().to_string()) .host(HostId(crate::utils::uuid_v7()))
.version("v1".into()) .version("v1".into())
.tag(crate::utils::uuid_v7().simple().to_string()) .tag(crate::utils::uuid_v7().simple().to_string())
.data(DecryptedData(vec![0, 1, 2, 3])) .data(DecryptedData(vec![0, 1, 2, 3]))
@ -304,7 +343,14 @@ mod tests {
let diff = index1.diff(&index2); let diff = index1.diff(&index2);
assert_eq!(1, diff.len(), "expected single diff"); assert_eq!(1, diff.len(), "expected single diff");
assert_eq!(diff[0], (record2.host, record2.tag, Some(record2.id))); assert_eq!(
diff[0],
Diff {
host: record2.host,
tag: record2.tag,
tail: record2.id
}
);
} }
#[test] #[test]
@ -342,12 +388,14 @@ mod tests {
assert_eq!(4, diff1.len()); assert_eq!(4, diff1.len());
assert_eq!(4, diff2.len()); assert_eq!(4, diff2.len());
dbg!(&diff1, &diff2);
// both diffs should be ALMOST the same. They will agree on which hosts and tags // both diffs should be ALMOST the same. They will agree on which hosts and tags
// require updating, but the "other" value will not be the same. // require updating, but the "other" value will not be the same.
let smol_diff_1: Vec<(String, String)> = let smol_diff_1: Vec<(HostId, String)> =
diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect(); diff1.iter().map(|v| (v.host, v.tag.clone())).collect();
let smol_diff_2: Vec<(String, String)> = let smol_diff_2: Vec<(HostId, String)> =
diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect(); diff1.iter().map(|v| (v.host, v.tag.clone())).collect();
assert_eq!(smol_diff_1, smol_diff_2); assert_eq!(smol_diff_1, smol_diff_2);

View File

@ -13,7 +13,10 @@ use self::{
models::{History, NewHistory, NewSession, NewUser, Session, User}, models::{History, NewHistory, NewSession, NewUser, Session, User},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::utils::get_days_from_month; use atuin_common::{
record::{EncryptedData, HostId, Record, RecordId, RecordIndex},
utils::get_days_from_month,
};
use chrono::{Datelike, TimeZone}; use chrono::{Datelike, TimeZone};
use chronoutil::RelativeDuration; use chronoutil::RelativeDuration;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
@ -55,6 +58,19 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>; async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>;
async fn next_records(
&self,
user: &User,
host: HostId,
tag: String,
start: Option<RecordId>,
count: u64,
) -> DbResult<Vec<Record<EncryptedData>>>;
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
async fn count_history_range( async fn count_history_range(
&self, &self,
user: &User, user: &User,

View File

@ -18,4 +18,5 @@ chrono = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
sqlx = { workspace = true } sqlx = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
uuid = { workspace = true }
futures-util = "0.3" futures-util = "0.3"

View File

@ -0,0 +1,5 @@
// generated by `sqlx migrate build-script`
fn main() {
// trigger recompilation when a new migration is added
println!("cargo:rerun-if-changed=migrations");
}

View File

@ -0,0 +1,15 @@
-- Add migration script here
create table records (
id uuid primary key, -- remember to use uuidv7 for happy indices <3
client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key
host uuid not null, -- a unique identifier for the host
parent uuid default null, -- the ID of the parent record, bearing in mind this is a linked list
timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
version text not null,
tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
data text not null, -- store the actual history data, encrypted. I don't wanna know!
cek text not null,
user_id bigint not null, -- allow multiple users
created_at timestamp not null default current_timestamp
);

View File

@ -1,14 +1,14 @@
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult}; use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt; use futures_util::TryStreamExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::postgres::PgPoolOptions; use sqlx::postgres::PgPoolOptions;
use sqlx::Row; use sqlx::Row;
use tracing::instrument; use tracing::instrument;
use wrappers::{DbHistory, DbSession, DbUser}; use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
mod wrappers; mod wrappers;
@ -329,4 +329,102 @@ impl Database for Postgres {
.map_err(fix_error) .map_err(fix_error)
.map(|DbHistory(h)| h) .map(|DbHistory(h)| h)
} }
#[instrument(skip_all)]
async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
let mut tx = self.pool.begin().await.map_err(fix_error)?;
for i in records {
let id = atuin_common::utils::uuid_v7();
sqlx::query(
"insert into records
(id, client_id, host, parent, timestamp, version, tag, data, cek, user_id)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
on conflict do nothing
",
)
.bind(id)
.bind(i.id)
.bind(i.host)
.bind(i.parent)
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
.bind(&i.version)
.bind(&i.tag)
.bind(&i.data.data)
.bind(&i.data.content_encryption_key)
.bind(user.id)
.execute(&mut tx)
.await
.map_err(fix_error)?;
}
tx.commit().await.map_err(fix_error)?;
Ok(())
}
#[instrument(skip_all)]
async fn next_records(
&self,
user: &User,
host: HostId,
tag: String,
start: Option<RecordId>,
count: u64,
) -> DbResult<Vec<Record<EncryptedData>>> {
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
let mut ret = Vec::with_capacity(count as usize);
let mut parent = start;
// yeah let's do something better
for _ in 0..count {
// a very much not ideal query. but it's simple at least?
// we are basically using postgres as a kv store here, so... maybe consider using an actual
// kv store?
let record: Result<DbRecord, DbError> = sqlx::query_as(
"select client_id, host, parent, timestamp, version, tag, data, cek from records
where user_id = $1
and tag = $2
and host = $3
and parent is not distinct from $4",
)
.bind(user.id)
.bind(tag.clone())
.bind(host)
.bind(parent)
.fetch_one(&self.pool)
.await
.map_err(fix_error);
match record {
Ok(record) => {
let record: Record<EncryptedData> = record.into();
ret.push(record.clone());
parent = Some(record.id);
}
Err(DbError::NotFound) => {
tracing::debug!("hit tail of store: {:?}/{}", host, tag);
return Ok(ret);
}
Err(e) => return Err(e),
}
}
Ok(ret)
}
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> {
const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0;";
let res = sqlx::query_as(TAIL_RECORDS_SQL)
.bind(user.id)
.fetch(&self.pool)
.try_collect()
.await
.map_err(fix_error)?;
Ok(res)
}
} }

View File

@ -1,10 +1,12 @@
use ::sqlx::{FromRow, Result}; use ::sqlx::{FromRow, Result};
use atuin_common::record::{EncryptedData, Record};
use atuin_server_database::models::{History, Session, User}; use atuin_server_database::models::{History, Session, User};
use sqlx::{postgres::PgRow, Row}; use sqlx::{postgres::PgRow, Row};
pub struct DbUser(pub User); pub struct DbUser(pub User);
pub struct DbSession(pub Session); pub struct DbSession(pub Session);
pub struct DbHistory(pub History); pub struct DbHistory(pub History);
pub struct DbRecord(pub Record<EncryptedData>);
impl<'a> FromRow<'a, PgRow> for DbUser { impl<'a> FromRow<'a, PgRow> for DbUser {
fn from_row(row: &'a PgRow) -> Result<Self> { fn from_row(row: &'a PgRow) -> Result<Self> {
@ -40,3 +42,30 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
})) }))
} }
} }
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
let timestamp: i64 = row.try_get("timestamp")?;
let data = EncryptedData {
data: row.try_get("data")?,
content_encryption_key: row.try_get("cek")?,
};
Ok(Self(Record {
id: row.try_get("client_id")?,
host: row.try_get("host")?,
parent: row.try_get("parent")?,
timestamp: timestamp as u64,
version: row.try_get("version")?,
tag: row.try_get("tag")?,
data,
}))
}
}
impl From<DbRecord> for Record<EncryptedData> {
fn from(other: DbRecord) -> Record<EncryptedData> {
Record { ..other.0 }
}
}

View File

@ -2,6 +2,7 @@ use atuin_common::api::{ErrorResponse, IndexResponse};
use axum::{response::IntoResponse, Json}; use axum::{response::IntoResponse, Json};
pub mod history; pub mod history;
pub mod record;
pub mod status; pub mod status;
pub mod user; pub mod user;

View File

@ -0,0 +1,104 @@
use axum::{extract::Query, extract::State, Json};
use http::StatusCode;
use serde::Deserialize;
use tracing::{error, instrument};
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::router::{AppState, UserAuth};
use atuin_server_database::Database;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn post<DB: Database>(
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
Json(records): Json<Vec<Record<EncryptedData>>>,
) -> Result<(), ErrorResponseStatus<'static>> {
let State(AppState { database, settings }) = state;
tracing::debug!(
count = records.len(),
user = user.username,
"request to add records"
);
let too_big = records
.iter()
.any(|r| r.data.data.len() >= settings.max_record_size || settings.max_record_size == 0);
if too_big {
return Err(
ErrorResponse::reply("could not add records; record too large")
.with_status(StatusCode::BAD_REQUEST),
);
}
if let Err(e) = database.add_records(&user, &records).await {
error!("failed to add record: {}", e);
return Err(ErrorResponse::reply("failed to add record")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
};
Ok(())
}
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn index<DB: Database>(
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> {
let State(AppState {
database,
settings: _,
}) = state;
let record_index = match database.tail_records(&user).await {
Ok(index) => index,
Err(e) => {
error!("failed to get record index: {}", e);
return Err(ErrorResponse::reply("failed to calculate record index")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
}
};
Ok(Json(record_index))
}
#[derive(Deserialize)]
pub struct NextParams {
host: HostId,
tag: String,
start: Option<RecordId>,
count: u64,
}
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn next<DB: Database>(
params: Query<NextParams>,
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> {
let State(AppState {
database,
settings: _,
}) = state;
let params = params.0;
let records = match database
.next_records(&user, params.host, params.tag, params.start, params.count)
.await
{
Ok(records) => records,
Err(e) => {
error!("failed to get record index: {}", e);
return Err(ErrorResponse::reply("failed to calculate record index")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
}
};
Ok(Json(records))
}

View File

@ -71,6 +71,9 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
.route("/sync/status", get(handlers::status::status)) .route("/sync/status", get(handlers::status::status))
.route("/history", post(handlers::history::add)) .route("/history", post(handlers::history::add))
.route("/history", delete(handlers::history::delete)) .route("/history", delete(handlers::history::delete))
.route("/record", post(handlers::record::post))
.route("/record", get(handlers::record::index))
.route("/record/next", get(handlers::record::next))
.route("/user/:username", get(handlers::user::get)) .route("/user/:username", get(handlers::user::get))
.route("/account", delete(handlers::user::delete)) .route("/account", delete(handlers::user::delete))
.route("/register", post(handlers::user::register)) .route("/register", post(handlers::user::register))

View File

@ -12,6 +12,7 @@ pub struct Settings<DbSettings> {
pub path: String, pub path: String,
pub open_registration: bool, pub open_registration: bool,
pub max_history_length: usize, pub max_history_length: usize,
pub max_record_size: usize,
pub page_size: i64, pub page_size: i64,
pub register_webhook_url: Option<String>, pub register_webhook_url: Option<String>,
pub register_webhook_username: String, pub register_webhook_username: String,
@ -39,6 +40,7 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
.set_default("port", 8888)? .set_default("port", 8888)?
.set_default("open_registration", false)? .set_default("open_registration", false)?
.set_default("max_history_length", 8192)? .set_default("max_history_length", 8192)?
.set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky
.set_default("path", "")? .set_default("path", "")?
.set_default("register_webhook_username", "")? .set_default("register_webhook_username", "")?
.set_default("page_size", 1100)? .set_default("page_size", 1100)?

View File

@ -69,7 +69,7 @@ impl Cmd {
Self::Search(search) => search.run(db, &mut settings).await, Self::Search(search) => search.run(db, &mut settings).await,
#[cfg(feature = "sync")] #[cfg(feature = "sync")]
Self::Sync(sync) => sync.run(settings, &mut db).await, Self::Sync(sync) => sync.run(settings, &mut db, &mut store).await,
#[cfg(feature = "sync")] #[cfg(feature = "sync")]
Self::Account(account) => account.run(settings).await, Self::Account(account) => account.run(settings).await,

View File

@ -1,7 +1,11 @@
use clap::Subcommand; use clap::Subcommand;
use eyre::{Result, WrapErr}; use eyre::{Result, WrapErr};
use atuin_client::{database::Database, settings::Settings}; use atuin_client::{
database::Database,
record::{store::Store, sync},
settings::Settings,
};
mod status; mod status;
@ -37,9 +41,14 @@ pub enum Cmd {
} }
impl Cmd { impl Cmd {
pub async fn run(self, settings: Settings, db: &mut impl Database) -> Result<()> { pub async fn run(
self,
settings: Settings,
db: &mut impl Database,
store: &mut (impl Store + Send + Sync),
) -> Result<()> {
match self { match self {
Self::Sync { force } => run(&settings, force, db).await, Self::Sync { force } => run(&settings, force, db, store).await,
Self::Login(l) => l.run(&settings).await, Self::Login(l) => l.run(&settings).await,
Self::Logout => account::logout::run(&settings), Self::Logout => account::logout::run(&settings),
Self::Register(r) => r.run(&settings).await, Self::Register(r) => r.run(&settings).await,
@ -62,12 +71,26 @@ impl Cmd {
} }
} }
async fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { async fn run(
settings: &Settings,
force: bool,
db: &mut impl Database,
store: &mut (impl Store + Send + Sync),
) -> Result<()> {
let (diff, remote_index) = sync::diff(settings, store).await?;
let operations = sync::operations(diff, store).await?;
let (uploaded, downloaded) =
sync::sync_remote(operations, &remote_index, store, settings).await?;
println!("{uploaded}/{downloaded} up/down to record store");
atuin_client::sync::sync(settings, force, db).await?; atuin_client::sync::sync(settings, force, db).await?;
println!( println!(
"Sync complete! {} items in database, force: {}", "Sync complete! {} items in history database, force: {}",
db.history_count().await?, db.history_count().await?,
force force
); );
Ok(()) Ok(())
} }