mirror of
https://github.com/atuinsh/atuin.git
synced 2025-06-20 18:07:57 +02:00
Add new sync (#1093)
* Add record migration * Add database functions for inserting history No real tests yet :( I would like to avoid running postgres lol * Add index handler, use UUIDs not strings * Fix a bunch of tests, remove Option<Uuid> * Add tests, all passing * Working upload sync * Record downloading works * Sync download works * Don't waste requests * Use a page size for uploads, make it variable later * Aaaaaand they're encrypted now too * Add cek * Allow reading tail across hosts * Revert "Allow reading tail across hosts" Not like that This reverts commit 7b0c72e7e050c358172f9b53cbd21b9e44cf4931. * Handle multiple shards properly * format * Format and make clippy happy * use some fancy types (#1098) * use some fancy types * fmt * Goodbye horrible tuple * Update atuin-server-postgres/migrations/20230623070418_records.sql Co-authored-by: Conrad Ludgate <conradludgate@gmail.com> * fmt * Sort tests too because time sucks * fix features --------- Co-authored-by: Conrad Ludgate <conradludgate@gmail.com>
This commit is contained in:
parent
3d4302ded1
commit
97e24d0d41
30
Cargo.lock
generated
30
Cargo.lock
generated
@ -142,6 +142,7 @@ dependencies = [
|
|||||||
"directories",
|
"directories",
|
||||||
"eyre",
|
"eyre",
|
||||||
"fs-err",
|
"fs-err",
|
||||||
|
"futures",
|
||||||
"generic-array",
|
"generic-array",
|
||||||
"hex",
|
"hex",
|
||||||
"interim",
|
"interim",
|
||||||
@ -151,6 +152,7 @@ dependencies = [
|
|||||||
"memchr",
|
"memchr",
|
||||||
"minspan",
|
"minspan",
|
||||||
"parse_duration",
|
"parse_duration",
|
||||||
|
"pretty_assertions",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"regex",
|
"regex",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
@ -182,6 +184,7 @@ dependencies = [
|
|||||||
"pretty_assertions",
|
"pretty_assertions",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"serde",
|
"serde",
|
||||||
|
"sqlx",
|
||||||
"typed-builder",
|
"typed-builder",
|
||||||
"uuid",
|
"uuid",
|
||||||
]
|
]
|
||||||
@ -240,6 +243,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"uuid",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -950,6 +954,21 @@ version = "2.9.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541"
|
checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures"
|
||||||
|
version = "0.3.24"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
|
||||||
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
|
"futures-core",
|
||||||
|
"futures-executor",
|
||||||
|
"futures-io",
|
||||||
|
"futures-sink",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-channel"
|
name = "futures-channel"
|
||||||
version = "0.3.24"
|
version = "0.3.24"
|
||||||
@ -988,6 +1007,12 @@ dependencies = [
|
|||||||
"parking_lot 0.11.2",
|
"parking_lot 0.11.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-io"
|
||||||
|
version = "0.3.28"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-macro"
|
name = "futures-macro"
|
||||||
version = "0.3.24"
|
version = "0.3.24"
|
||||||
@ -1017,10 +1042,13 @@ version = "0.3.24"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90"
|
checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-io",
|
||||||
"futures-macro",
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
|
"memchr",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
"slab",
|
"slab",
|
||||||
@ -2567,6 +2595,7 @@ dependencies = [
|
|||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"url",
|
"url",
|
||||||
|
"uuid",
|
||||||
"webpki-roots",
|
"webpki-roots",
|
||||||
"whoami",
|
"whoami",
|
||||||
]
|
]
|
||||||
@ -3037,6 +3066,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81"
|
checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getrandom 0.2.7",
|
"getrandom 0.2.7",
|
||||||
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
17
Cargo.toml
17
Cargo.toml
@ -1,11 +1,11 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"atuin",
|
"atuin",
|
||||||
"atuin-client",
|
"atuin-client",
|
||||||
"atuin-server",
|
"atuin-server",
|
||||||
"atuin-server-postgres",
|
"atuin-server-postgres",
|
||||||
"atuin-server-database",
|
"atuin-server-database",
|
||||||
"atuin-common",
|
"atuin-common",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
@ -35,9 +35,10 @@ semver = "1.0.14"
|
|||||||
serde = { version = "1.0.145", features = ["derive"] }
|
serde = { version = "1.0.145", features = ["derive"] }
|
||||||
serde_json = "1.0.86"
|
serde_json = "1.0.86"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
uuid = { version = "1.3", features = ["v4"] }
|
uuid = { version = "1.3", features = ["v4", "serde"] }
|
||||||
whoami = "1.1.2"
|
whoami = "1.1.2"
|
||||||
typed-builder = "0.14.0"
|
typed-builder = "0.14.0"
|
||||||
|
pretty_assertions = "1.3.0"
|
||||||
|
|
||||||
[workspace.dependencies.reqwest]
|
[workspace.dependencies.reqwest]
|
||||||
version = "0.11"
|
version = "0.11"
|
||||||
@ -46,4 +47,4 @@ default-features = false
|
|||||||
|
|
||||||
[workspace.dependencies.sqlx]
|
[workspace.dependencies.sqlx]
|
||||||
version = "0.6"
|
version = "0.6"
|
||||||
features = ["runtime-tokio-rustls", "chrono", "postgres"]
|
features = ["runtime-tokio-rustls", "chrono", "postgres", "uuid"]
|
||||||
|
@ -54,10 +54,14 @@ rmp = { version = "0.8.11" }
|
|||||||
typed-builder = "0.14.0"
|
typed-builder = "0.14.0"
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
semver = { workspace = true }
|
semver = { workspace = true }
|
||||||
|
futures = "0.3"
|
||||||
|
|
||||||
# encryption
|
# encryption
|
||||||
rusty_paseto = { version = "0.5.0", default-features = false }
|
rusty_paseto = { version = "0.5.0", default-features = false }
|
||||||
rusty_paserk = { version = "0.2.0", default-features = false, features = ["v4", "serde"] }
|
rusty_paserk = { version = "0.2.0", default-features = false, features = [
|
||||||
|
"v4",
|
||||||
|
"serde",
|
||||||
|
] }
|
||||||
|
|
||||||
# sync
|
# sync
|
||||||
urlencoding = { version = "2.1.0", optional = true }
|
urlencoding = { version = "2.1.0", optional = true }
|
||||||
@ -69,3 +73,4 @@ generic-array = { version = "0.14", optional = true, features = ["serde"] }
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
pretty_assertions = { workspace = true }
|
||||||
|
@ -7,7 +7,8 @@ create table if not exists records (
|
|||||||
timestamp integer not null,
|
timestamp integer not null,
|
||||||
tag text not null,
|
tag text not null,
|
||||||
version text not null,
|
version text not null,
|
||||||
data blob not null
|
data blob not null,
|
||||||
|
cek blob not null
|
||||||
);
|
);
|
||||||
|
|
||||||
create index host_idx on records (host);
|
create index host_idx on records (host);
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
-- store content encryption keys in the record
|
|
||||||
alter table records
|
|
||||||
add column cek text;
|
|
@ -8,9 +8,13 @@ use reqwest::{
|
|||||||
StatusCode, Url,
|
StatusCode, Url,
|
||||||
};
|
};
|
||||||
|
|
||||||
use atuin_common::api::{
|
use atuin_common::record::{EncryptedData, HostId, Record, RecordId};
|
||||||
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
|
use atuin_common::{
|
||||||
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
|
api::{
|
||||||
|
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
|
||||||
|
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
|
||||||
|
},
|
||||||
|
record::RecordIndex,
|
||||||
};
|
};
|
||||||
use semver::Version;
|
use semver::Version;
|
||||||
|
|
||||||
@ -195,6 +199,55 @@ impl<'a> Client<'a> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
|
||||||
|
let url = format!("{}/record", self.sync_addr);
|
||||||
|
let url = Url::parse(url.as_str())?;
|
||||||
|
|
||||||
|
self.client.post(url).json(records).send().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn next_records(
|
||||||
|
&self,
|
||||||
|
host: HostId,
|
||||||
|
tag: String,
|
||||||
|
start: Option<RecordId>,
|
||||||
|
count: u64,
|
||||||
|
) -> Result<Vec<Record<EncryptedData>>> {
|
||||||
|
let url = format!(
|
||||||
|
"{}/record/next?host={}&tag={}&count={}",
|
||||||
|
self.sync_addr, host.0, tag, count
|
||||||
|
);
|
||||||
|
let mut url = Url::parse(url.as_str())?;
|
||||||
|
|
||||||
|
if let Some(start) = start {
|
||||||
|
url.set_query(Some(
|
||||||
|
format!(
|
||||||
|
"host={}&tag={}&count={}&start={}",
|
||||||
|
host.0, tag, count, start.0
|
||||||
|
)
|
||||||
|
.as_str(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = self.client.get(url).send().await?;
|
||||||
|
|
||||||
|
let records = resp.json::<Vec<Record<EncryptedData>>>().await?;
|
||||||
|
|
||||||
|
Ok(records)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn record_index(&self) -> Result<RecordIndex> {
|
||||||
|
let url = format!("{}/record", self.sync_addr);
|
||||||
|
let url = Url::parse(url.as_str())?;
|
||||||
|
|
||||||
|
let resp = self.client.get(url).send().await?;
|
||||||
|
let index = resp.json().await?;
|
||||||
|
|
||||||
|
Ok(index)
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn delete(&self) -> Result<()> {
|
pub async fn delete(&self) -> Result<()> {
|
||||||
let url = format!("{}/account", self.sync_addr);
|
let url = format!("{}/account", self.sync_addr);
|
||||||
let url = Url::parse(url.as_str())?;
|
let url = Url::parse(url.as_str())?;
|
||||||
|
@ -57,7 +57,7 @@ pub fn current_context() -> Context {
|
|||||||
session,
|
session,
|
||||||
hostname,
|
hostname,
|
||||||
cwd,
|
cwd,
|
||||||
host_id,
|
host_id: host_id.0.as_simple().to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,10 +101,7 @@ impl KvStore {
|
|||||||
|
|
||||||
let bytes = record.serialize()?;
|
let bytes = record.serialize()?;
|
||||||
|
|
||||||
let parent = store
|
let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id);
|
||||||
.last(host_id.as_str(), KV_TAG)
|
|
||||||
.await?
|
|
||||||
.map(|entry| entry.id);
|
|
||||||
|
|
||||||
let record = atuin_common::record::Record::builder()
|
let record = atuin_common::record::Record::builder()
|
||||||
.host(host_id)
|
.host(host_id)
|
||||||
@ -130,17 +127,22 @@ impl KvStore {
|
|||||||
namespace: &str,
|
namespace: &str,
|
||||||
key: &str,
|
key: &str,
|
||||||
) -> Result<Option<KvRecord>> {
|
) -> Result<Option<KvRecord>> {
|
||||||
// TODO: don't load this from disk so much
|
|
||||||
let host_id = Settings::host_id().expect("failed to get host_id");
|
|
||||||
|
|
||||||
// Currently, this is O(n). When we have an actual KV store, it can be better
|
// Currently, this is O(n). When we have an actual KV store, it can be better
|
||||||
// Just a poc for now!
|
// Just a poc for now!
|
||||||
|
|
||||||
// iterate records to find the value we want
|
// iterate records to find the value we want
|
||||||
// start at the end, so we get the most recent version
|
// start at the end, so we get the most recent version
|
||||||
let Some(mut record) = store.last(host_id.as_str(), KV_TAG).await? else {
|
let tails = store.tag_tails(KV_TAG).await?;
|
||||||
|
|
||||||
|
if tails.is_empty() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
}
|
||||||
|
|
||||||
|
// first, decide on a record.
|
||||||
|
// try getting the newest first
|
||||||
|
// we always need a way of deciding the "winner" of a write
|
||||||
|
// TODO(ellie): something better than last-write-wins, what if two write at the same time?
|
||||||
|
let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let decrypted = match record.version.as_str() {
|
let decrypted = match record.version.as_str() {
|
||||||
@ -154,7 +156,7 @@ impl KvStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(parent) = decrypted.parent {
|
if let Some(parent) = decrypted.parent {
|
||||||
record = store.get(parent.as_str()).await?;
|
record = store.get(parent).await?;
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
use atuin_common::record::{AdditionalData, DecryptedData, EncryptedData, Encryption};
|
use atuin_common::record::{
|
||||||
|
AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
|
||||||
|
};
|
||||||
use base64::{engine::general_purpose, Engine};
|
use base64::{engine::general_purpose, Engine};
|
||||||
use eyre::{ensure, Context, Result};
|
use eyre::{ensure, Context, Result};
|
||||||
use rusty_paserk::{Key, KeyId, Local, PieWrappedKey};
|
use rusty_paserk::{Key, KeyId, Local, PieWrappedKey};
|
||||||
@ -158,10 +160,11 @@ struct AtuinFooter {
|
|||||||
// This cannot be changed, otherwise it breaks the authenticated encryption.
|
// This cannot be changed, otherwise it breaks the authenticated encryption.
|
||||||
#[derive(Debug, Copy, Clone, Serialize)]
|
#[derive(Debug, Copy, Clone, Serialize)]
|
||||||
struct Assertions<'a> {
|
struct Assertions<'a> {
|
||||||
id: &'a str,
|
id: &'a RecordId,
|
||||||
version: &'a str,
|
version: &'a str,
|
||||||
tag: &'a str,
|
tag: &'a str,
|
||||||
host: &'a str,
|
host: &'a HostId,
|
||||||
|
parent: Option<&'a RecordId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
|
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
|
||||||
@ -171,6 +174,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
|
|||||||
version: ad.version,
|
version: ad.version,
|
||||||
tag: ad.tag,
|
tag: ad.tag,
|
||||||
host: ad.host,
|
host: ad.host,
|
||||||
|
parent: ad.parent,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -183,7 +187,7 @@ impl Assertions<'_> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use atuin_common::record::Record;
|
use atuin_common::{record::Record, utils::uuid_v7};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
@ -192,10 +196,11 @@ mod tests {
|
|||||||
let key = Key::<V4, Local>::new_os_random();
|
let key = Key::<V4, Local>::new_os_random();
|
||||||
|
|
||||||
let ad = AdditionalData {
|
let ad = AdditionalData {
|
||||||
id: "foo",
|
id: &RecordId(uuid_v7()),
|
||||||
version: "v0",
|
version: "v0",
|
||||||
tag: "kv",
|
tag: "kv",
|
||||||
host: "1234",
|
host: &HostId(uuid_v7()),
|
||||||
|
parent: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||||
@ -210,10 +215,11 @@ mod tests {
|
|||||||
let key = Key::<V4, Local>::new_os_random();
|
let key = Key::<V4, Local>::new_os_random();
|
||||||
|
|
||||||
let ad = AdditionalData {
|
let ad = AdditionalData {
|
||||||
id: "foo",
|
id: &RecordId(uuid_v7()),
|
||||||
version: "v0",
|
version: "v0",
|
||||||
tag: "kv",
|
tag: "kv",
|
||||||
host: "1234",
|
host: &HostId(uuid_v7()),
|
||||||
|
parent: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||||
@ -233,10 +239,11 @@ mod tests {
|
|||||||
let fake_key = Key::<V4, Local>::new_os_random();
|
let fake_key = Key::<V4, Local>::new_os_random();
|
||||||
|
|
||||||
let ad = AdditionalData {
|
let ad = AdditionalData {
|
||||||
id: "foo",
|
id: &RecordId(uuid_v7()),
|
||||||
version: "v0",
|
version: "v0",
|
||||||
tag: "kv",
|
tag: "kv",
|
||||||
host: "1234",
|
host: &HostId(uuid_v7()),
|
||||||
|
parent: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||||
@ -250,10 +257,11 @@ mod tests {
|
|||||||
let key = Key::<V4, Local>::new_os_random();
|
let key = Key::<V4, Local>::new_os_random();
|
||||||
|
|
||||||
let ad = AdditionalData {
|
let ad = AdditionalData {
|
||||||
id: "foo",
|
id: &RecordId(uuid_v7()),
|
||||||
version: "v0",
|
version: "v0",
|
||||||
tag: "kv",
|
tag: "kv",
|
||||||
host: "1234",
|
host: &HostId(uuid_v7()),
|
||||||
|
parent: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||||
@ -261,10 +269,8 @@ mod tests {
|
|||||||
let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes());
|
let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes());
|
||||||
|
|
||||||
let ad = AdditionalData {
|
let ad = AdditionalData {
|
||||||
id: "foo1",
|
id: &RecordId(uuid_v7()),
|
||||||
version: "v0",
|
..ad
|
||||||
tag: "kv",
|
|
||||||
host: "1234",
|
|
||||||
};
|
};
|
||||||
let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err();
|
let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err();
|
||||||
}
|
}
|
||||||
@ -275,10 +281,11 @@ mod tests {
|
|||||||
let key2 = Key::<V4, Local>::new_os_random();
|
let key2 = Key::<V4, Local>::new_os_random();
|
||||||
|
|
||||||
let ad = AdditionalData {
|
let ad = AdditionalData {
|
||||||
id: "foo",
|
id: &RecordId(uuid_v7()),
|
||||||
version: "v0",
|
version: "v0",
|
||||||
tag: "kv",
|
tag: "kv",
|
||||||
host: "1234",
|
host: &HostId(uuid_v7()),
|
||||||
|
parent: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||||
@ -304,10 +311,10 @@ mod tests {
|
|||||||
fn full_record_round_trip() {
|
fn full_record_round_trip() {
|
||||||
let key = [0x55; 32];
|
let key = [0x55; 32];
|
||||||
let record = Record::builder()
|
let record = Record::builder()
|
||||||
.id("1".to_owned())
|
.id(RecordId(uuid_v7()))
|
||||||
.version("v0".to_owned())
|
.version("v0".to_owned())
|
||||||
.tag("kv".to_owned())
|
.tag("kv".to_owned())
|
||||||
.host("host1".to_owned())
|
.host(HostId(uuid_v7()))
|
||||||
.timestamp(1687244806000000)
|
.timestamp(1687244806000000)
|
||||||
.data(DecryptedData(vec![1, 2, 3, 4]))
|
.data(DecryptedData(vec![1, 2, 3, 4]))
|
||||||
.build();
|
.build();
|
||||||
@ -316,30 +323,20 @@ mod tests {
|
|||||||
|
|
||||||
assert!(!encrypted.data.data.is_empty());
|
assert!(!encrypted.data.data.is_empty());
|
||||||
assert!(!encrypted.data.content_encryption_key.is_empty());
|
assert!(!encrypted.data.content_encryption_key.is_empty());
|
||||||
assert_eq!(encrypted.id, "1");
|
|
||||||
assert_eq!(encrypted.host, "host1");
|
|
||||||
assert_eq!(encrypted.version, "v0");
|
|
||||||
assert_eq!(encrypted.tag, "kv");
|
|
||||||
assert_eq!(encrypted.timestamp, 1687244806000000);
|
|
||||||
|
|
||||||
let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap();
|
let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap();
|
||||||
|
|
||||||
assert_eq!(decrypted.data.0, [1, 2, 3, 4]);
|
assert_eq!(decrypted.data.0, [1, 2, 3, 4]);
|
||||||
assert_eq!(decrypted.id, "1");
|
|
||||||
assert_eq!(decrypted.host, "host1");
|
|
||||||
assert_eq!(decrypted.version, "v0");
|
|
||||||
assert_eq!(decrypted.tag, "kv");
|
|
||||||
assert_eq!(decrypted.timestamp, 1687244806000000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn full_record_round_trip_fail() {
|
fn full_record_round_trip_fail() {
|
||||||
let key = [0x55; 32];
|
let key = [0x55; 32];
|
||||||
let record = Record::builder()
|
let record = Record::builder()
|
||||||
.id("1".to_owned())
|
.id(RecordId(uuid_v7()))
|
||||||
.version("v0".to_owned())
|
.version("v0".to_owned())
|
||||||
.tag("kv".to_owned())
|
.tag("kv".to_owned())
|
||||||
.host("host1".to_owned())
|
.host(HostId(uuid_v7()))
|
||||||
.timestamp(1687244806000000)
|
.timestamp(1687244806000000)
|
||||||
.data(DecryptedData(vec![1, 2, 3, 4]))
|
.data(DecryptedData(vec![1, 2, 3, 4]))
|
||||||
.build();
|
.build();
|
||||||
@ -347,13 +344,13 @@ mod tests {
|
|||||||
let encrypted = record.encrypt::<PASETO_V4>(&key);
|
let encrypted = record.encrypt::<PASETO_V4>(&key);
|
||||||
|
|
||||||
let mut enc1 = encrypted.clone();
|
let mut enc1 = encrypted.clone();
|
||||||
enc1.host = "host2".to_owned();
|
enc1.host = HostId(uuid_v7());
|
||||||
let _ = enc1
|
let _ = enc1
|
||||||
.decrypt::<PASETO_V4>(&key)
|
.decrypt::<PASETO_V4>(&key)
|
||||||
.expect_err("tampering with the host should result in auth failure");
|
.expect_err("tampering with the host should result in auth failure");
|
||||||
|
|
||||||
let mut enc2 = encrypted;
|
let mut enc2 = encrypted;
|
||||||
enc2.id = "2".to_owned();
|
enc2.id = RecordId(uuid_v7());
|
||||||
let _ = enc2
|
let _ = enc2
|
||||||
.decrypt::<PASETO_V4>(&key)
|
.decrypt::<PASETO_V4>(&key)
|
||||||
.expect_err("tampering with the id should result in auth failure");
|
.expect_err("tampering with the id should result in auth failure");
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
pub mod encryption;
|
pub mod encryption;
|
||||||
pub mod sqlite_store;
|
pub mod sqlite_store;
|
||||||
pub mod store;
|
pub mod store;
|
||||||
|
#[cfg(feature = "sync")]
|
||||||
|
pub mod sync;
|
||||||
|
@ -8,12 +8,14 @@ use std::str::FromStr;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use eyre::{eyre, Result};
|
use eyre::{eyre, Result};
|
||||||
use fs_err as fs;
|
use fs_err as fs;
|
||||||
|
use futures::TryStreamExt;
|
||||||
use sqlx::{
|
use sqlx::{
|
||||||
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
|
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
|
||||||
Row,
|
Row,
|
||||||
};
|
};
|
||||||
|
|
||||||
use atuin_common::record::{EncryptedData, Record};
|
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::store::Store;
|
use super::store::Store;
|
||||||
|
|
||||||
@ -62,11 +64,11 @@ impl SqliteStore {
|
|||||||
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
|
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
|
||||||
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
||||||
)
|
)
|
||||||
.bind(r.id.as_str())
|
.bind(r.id.0.as_simple().to_string())
|
||||||
.bind(r.host.as_str())
|
.bind(r.host.0.as_simple().to_string())
|
||||||
.bind(r.tag.as_str())
|
.bind(r.tag.as_str())
|
||||||
.bind(r.timestamp as i64)
|
.bind(r.timestamp as i64)
|
||||||
.bind(r.parent.as_ref())
|
.bind(r.parent.map(|p| p.0.as_simple().to_string()))
|
||||||
.bind(r.version.as_str())
|
.bind(r.version.as_str())
|
||||||
.bind(r.data.data.as_str())
|
.bind(r.data.data.as_str())
|
||||||
.bind(r.data.content_encryption_key.as_str())
|
.bind(r.data.content_encryption_key.as_str())
|
||||||
@ -79,10 +81,18 @@ impl SqliteStore {
|
|||||||
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
|
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
|
||||||
let timestamp: i64 = row.get("timestamp");
|
let timestamp: i64 = row.get("timestamp");
|
||||||
|
|
||||||
|
// tbh at this point things are pretty fucked so just panic
|
||||||
|
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
|
||||||
|
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
|
||||||
|
let parent: Option<&str> = row.get("parent");
|
||||||
|
|
||||||
|
let parent = parent
|
||||||
|
.map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
|
||||||
|
|
||||||
Record {
|
Record {
|
||||||
id: row.get("id"),
|
id: RecordId(id),
|
||||||
host: row.get("host"),
|
host: HostId(host),
|
||||||
parent: row.get("parent"),
|
parent: parent.map(RecordId),
|
||||||
timestamp: timestamp as u64,
|
timestamp: timestamp as u64,
|
||||||
tag: row.get("tag"),
|
tag: row.get("tag"),
|
||||||
version: row.get("version"),
|
version: row.get("version"),
|
||||||
@ -111,9 +121,9 @@ impl Store for SqliteStore {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get(&self, id: &str) -> Result<Record<EncryptedData>> {
|
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
|
||||||
let res = sqlx::query("select * from records where id = ?1")
|
let res = sqlx::query("select * from records where id = ?1")
|
||||||
.bind(id)
|
.bind(id.0.as_simple().to_string())
|
||||||
.map(Self::query_row)
|
.map(Self::query_row)
|
||||||
.fetch_one(&self.pool)
|
.fetch_one(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
@ -121,10 +131,10 @@ impl Store for SqliteStore {
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn len(&self, host: &str, tag: &str) -> Result<u64> {
|
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
|
||||||
let res: (i64,) =
|
let res: (i64,) =
|
||||||
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
|
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
|
||||||
.bind(host)
|
.bind(host.0.as_simple().to_string())
|
||||||
.bind(tag)
|
.bind(tag)
|
||||||
.fetch_one(&self.pool)
|
.fetch_one(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
@ -134,7 +144,7 @@ impl Store for SqliteStore {
|
|||||||
|
|
||||||
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
|
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
|
||||||
let res = sqlx::query("select * from records where parent = ?1")
|
let res = sqlx::query("select * from records where parent = ?1")
|
||||||
.bind(record.id.clone())
|
.bind(record.id.0.as_simple().to_string())
|
||||||
.map(Self::query_row)
|
.map(Self::query_row)
|
||||||
.fetch_one(&self.pool)
|
.fetch_one(&self.pool)
|
||||||
.await;
|
.await;
|
||||||
@ -146,11 +156,11 @@ impl Store for SqliteStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
||||||
let res = sqlx::query(
|
let res = sqlx::query(
|
||||||
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
|
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
|
||||||
)
|
)
|
||||||
.bind(host)
|
.bind(host.0.as_simple().to_string())
|
||||||
.bind(tag)
|
.bind(tag)
|
||||||
.map(Self::query_row)
|
.map(Self::query_row)
|
||||||
.fetch_optional(&self.pool)
|
.fetch_optional(&self.pool)
|
||||||
@ -159,23 +169,53 @@ impl Store for SqliteStore {
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
||||||
let res = sqlx::query(
|
let res = sqlx::query(
|
||||||
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
|
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
|
||||||
)
|
)
|
||||||
.bind(tag)
|
.bind(tag)
|
||||||
.bind(host)
|
.bind(host.0.as_simple().to_string())
|
||||||
.map(Self::query_row)
|
.map(Self::query_row)
|
||||||
.fetch_optional(&self.pool)
|
.fetch_optional(&self.pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
|
||||||
|
let res = sqlx::query(
|
||||||
|
"select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
|
||||||
|
)
|
||||||
|
.bind(tag)
|
||||||
|
.map(Self::query_row)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tail_records(&self) -> Result<RecordIndex> {
|
||||||
|
let res = sqlx::query(
|
||||||
|
"select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
|
||||||
|
)
|
||||||
|
.map(|row: SqliteRow| {
|
||||||
|
let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
|
||||||
|
let tag: String= row.get("tag");
|
||||||
|
let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
|
||||||
|
|
||||||
|
(HostId(host), tag, RecordId(id))
|
||||||
|
})
|
||||||
|
.fetch(&self.pool)
|
||||||
|
.try_collect()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use atuin_common::record::{EncryptedData, Record};
|
use atuin_common::record::{EncryptedData, HostId, Record};
|
||||||
|
|
||||||
use crate::record::{encryption::PASETO_V4, store::Store};
|
use crate::record::{encryption::PASETO_V4, store::Store};
|
||||||
|
|
||||||
@ -183,7 +223,7 @@ mod tests {
|
|||||||
|
|
||||||
fn test_record() -> Record<EncryptedData> {
|
fn test_record() -> Record<EncryptedData> {
|
||||||
Record::builder()
|
Record::builder()
|
||||||
.host(atuin_common::utils::uuid_v7().simple().to_string())
|
.host(HostId(atuin_common::utils::uuid_v7()))
|
||||||
.version("v1".into())
|
.version("v1".into())
|
||||||
.tag(atuin_common::utils::uuid_v7().simple().to_string())
|
.tag(atuin_common::utils::uuid_v7().simple().to_string())
|
||||||
.data(EncryptedData {
|
.data(EncryptedData {
|
||||||
@ -218,10 +258,7 @@ mod tests {
|
|||||||
let record = test_record();
|
let record = test_record();
|
||||||
db.push(&record).await.unwrap();
|
db.push(&record).await.unwrap();
|
||||||
|
|
||||||
let new_record = db
|
let new_record = db.get(record.id).await.expect("failed to fetch record");
|
||||||
.get(record.id.as_str())
|
|
||||||
.await
|
|
||||||
.expect("failed to fetch record");
|
|
||||||
|
|
||||||
assert_eq!(record, new_record, "records are not equal");
|
assert_eq!(record, new_record, "records are not equal");
|
||||||
}
|
}
|
||||||
@ -233,7 +270,7 @@ mod tests {
|
|||||||
db.push(&record).await.unwrap();
|
db.push(&record).await.unwrap();
|
||||||
|
|
||||||
let len = db
|
let len = db
|
||||||
.len(record.host.as_str(), record.tag.as_str())
|
.len(record.host, record.tag.as_str())
|
||||||
.await
|
.await
|
||||||
.expect("failed to get store len");
|
.expect("failed to get store len");
|
||||||
|
|
||||||
@ -253,14 +290,8 @@ mod tests {
|
|||||||
db.push(&first).await.unwrap();
|
db.push(&first).await.unwrap();
|
||||||
db.push(&second).await.unwrap();
|
db.push(&second).await.unwrap();
|
||||||
|
|
||||||
let first_len = db
|
let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
|
||||||
.len(first.host.as_str(), first.tag.as_str())
|
let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let second_len = db
|
|
||||||
.len(second.host.as_str(), second.tag.as_str())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(first_len, 1, "expected length of 1 after insert");
|
assert_eq!(first_len, 1, "expected length of 1 after insert");
|
||||||
assert_eq!(second_len, 1, "expected length of 1 after insert");
|
assert_eq!(second_len, 1, "expected length of 1 after insert");
|
||||||
@ -281,7 +312,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
|
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
|
||||||
100,
|
100,
|
||||||
"failed to insert 100 records"
|
"failed to insert 100 records"
|
||||||
);
|
);
|
||||||
@ -304,7 +335,7 @@ mod tests {
|
|||||||
db.push_batch(records.iter()).await.unwrap();
|
db.push_batch(records.iter()).await.unwrap();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
|
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
|
||||||
10000,
|
10000,
|
||||||
"failed to insert 10k records"
|
"failed to insert 10k records"
|
||||||
);
|
);
|
||||||
@ -327,7 +358,7 @@ mod tests {
|
|||||||
db.push_batch(records.iter()).await.unwrap();
|
db.push_batch(records.iter()).await.unwrap();
|
||||||
|
|
||||||
let mut record = db
|
let mut record = db
|
||||||
.first(tail.host.as_str(), tail.tag.as_str())
|
.head(tail.host, tail.tag.as_str())
|
||||||
.await
|
.await
|
||||||
.expect("in memory sqlite should not fail")
|
.expect("in memory sqlite should not fail")
|
||||||
.expect("entry exists");
|
.expect("entry exists");
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use eyre::Result;
|
use eyre::Result;
|
||||||
|
|
||||||
use atuin_common::record::{EncryptedData, Record};
|
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||||
|
|
||||||
/// A record store stores records
|
/// A record store stores records
|
||||||
/// In more detail - we tend to need to process this into _another_ format to actually query it.
|
/// In more detail - we tend to need to process this into _another_ format to actually query it.
|
||||||
@ -20,14 +20,22 @@ pub trait Store {
|
|||||||
records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
|
records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
|
||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
|
|
||||||
async fn get(&self, id: &str) -> Result<Record<EncryptedData>>;
|
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
|
||||||
async fn len(&self, host: &str, tag: &str) -> Result<u64>;
|
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
|
||||||
|
|
||||||
/// Get the record that follows this record
|
/// Get the record that follows this record
|
||||||
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
|
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
|
||||||
|
|
||||||
/// Get the first record for a given host and tag
|
/// Get the first record for a given host and tag
|
||||||
async fn first(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||||
|
|
||||||
/// Get the last record for a given host and tag
|
/// Get the last record for a given host and tag
|
||||||
async fn last(&self, host: &str, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||||
|
|
||||||
|
// Get the last record for all hosts for a given tag, useful for the read path of apps.
|
||||||
|
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
|
||||||
|
|
||||||
|
// Get the latest host/tag/record tuple for every set in the store. useful for building an
|
||||||
|
// index
|
||||||
|
async fn tail_records(&self) -> Result<RecordIndex>;
|
||||||
}
|
}
|
||||||
|
421
atuin-client/src/record/sync.rs
Normal file
421
atuin-client/src/record/sync.rs
Normal file
@ -0,0 +1,421 @@
|
|||||||
|
// do a sync :O
|
||||||
|
use eyre::Result;
|
||||||
|
|
||||||
|
use super::store::Store;
|
||||||
|
use crate::{api_client::Client, settings::Settings};
|
||||||
|
|
||||||
|
use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
|
||||||
|
|
||||||
|
#[derive(Debug, Eq, PartialEq)]
|
||||||
|
pub enum Operation {
|
||||||
|
// Either upload or download until the tail matches the below
|
||||||
|
Upload {
|
||||||
|
tail: RecordId,
|
||||||
|
host: HostId,
|
||||||
|
tag: String,
|
||||||
|
},
|
||||||
|
Download {
|
||||||
|
tail: RecordId,
|
||||||
|
host: HostId,
|
||||||
|
tag: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
|
||||||
|
let client = Client::new(&settings.sync_address, &settings.session_token)?;
|
||||||
|
|
||||||
|
let local_index = store.tail_records().await?;
|
||||||
|
let remote_index = client.record_index().await?;
|
||||||
|
|
||||||
|
let diff = local_index.diff(&remote_index);
|
||||||
|
|
||||||
|
Ok((diff, remote_index))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take a diff, along with a local store, and resolve it into a set of operations.
|
||||||
|
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
|
||||||
|
// In theory this could be done as a part of the diffing stage, but it's easier to reason
|
||||||
|
// about and test this way
|
||||||
|
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
|
||||||
|
let mut operations = Vec::with_capacity(diffs.len());
|
||||||
|
|
||||||
|
for diff in diffs {
|
||||||
|
// First, try to fetch the tail
|
||||||
|
// If it exists locally, then that means we need to update the remote
|
||||||
|
// host until it has the same tail. Ie, upload.
|
||||||
|
// If it does not exist locally, that means remote is ahead of us.
|
||||||
|
// Therefore, we need to download until our local tail matches
|
||||||
|
let record = store.get(diff.tail).await;
|
||||||
|
|
||||||
|
let op = if record.is_ok() {
|
||||||
|
// if local has the ID, then we should find the actual tail of this
|
||||||
|
// store, so we know what we need to update the remote to.
|
||||||
|
let tail = store
|
||||||
|
.tail(diff.host, diff.tag.as_str())
|
||||||
|
.await?
|
||||||
|
.expect("failed to fetch last record, expected tag/host to exist");
|
||||||
|
|
||||||
|
// TODO(ellie) update the diffing so that it stores the context of the current tail
|
||||||
|
// that way, we can determine how much we need to upload.
|
||||||
|
// For now just keep uploading until tails match
|
||||||
|
|
||||||
|
Operation::Upload {
|
||||||
|
tail: tail.id,
|
||||||
|
host: diff.host,
|
||||||
|
tag: diff.tag,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Operation::Download {
|
||||||
|
tail: diff.tail,
|
||||||
|
host: diff.host,
|
||||||
|
tag: diff.tag,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
operations.push(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort them - purely so we have a stable testing order, and can rely on
|
||||||
|
// same input = same output
|
||||||
|
// We can sort by ID so long as we continue to use UUIDv7 or something
|
||||||
|
// with the same properties
|
||||||
|
|
||||||
|
operations.sort_by_key(|op| match op {
|
||||||
|
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
|
||||||
|
Operation::Download { tail, host, .. } => ("download", *host, *tail),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(operations)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sync_upload(
|
||||||
|
store: &mut impl Store,
|
||||||
|
remote_index: &RecordIndex,
|
||||||
|
client: &Client<'_>,
|
||||||
|
op: (HostId, String, RecordId),
|
||||||
|
) -> Result<i64> {
|
||||||
|
let upload_page_size = 100;
|
||||||
|
let mut total = 0;
|
||||||
|
|
||||||
|
// so. we have an upload operation, with the tail representing the state
|
||||||
|
// we want to get the remote to
|
||||||
|
let current_tail = remote_index.get(op.0, op.1.clone());
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Syncing local {:?}/{}/{:?}, remote has {:?}",
|
||||||
|
op.0, op.1, op.2, current_tail
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = if let Some(current_tail) = current_tail {
|
||||||
|
current_tail
|
||||||
|
} else {
|
||||||
|
store
|
||||||
|
.head(op.0, op.1.as_str())
|
||||||
|
.await
|
||||||
|
.expect("failed to fetch host/tag head")
|
||||||
|
.expect("host/tag not in current index")
|
||||||
|
.id
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("starting push to remote from: {:?}", start);
|
||||||
|
|
||||||
|
// we have the start point for sync. it is either the head of the store if
|
||||||
|
// the remote has no data for it, or the tail that the remote has
|
||||||
|
// we need to iterate from the remote tail, and keep going until
|
||||||
|
// remote tail = current local tail
|
||||||
|
|
||||||
|
let mut record = Some(store.get(start).await.unwrap());
|
||||||
|
|
||||||
|
let mut buf = Vec::with_capacity(upload_page_size);
|
||||||
|
|
||||||
|
while let Some(r) = record {
|
||||||
|
if buf.len() < upload_page_size {
|
||||||
|
buf.push(r.clone());
|
||||||
|
} else {
|
||||||
|
client.post_records(&buf).await?;
|
||||||
|
|
||||||
|
// can we reset what we have? len = 0 but keep capacity
|
||||||
|
buf = Vec::with_capacity(upload_page_size);
|
||||||
|
}
|
||||||
|
record = store.next(&r).await?;
|
||||||
|
|
||||||
|
total += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !buf.is_empty() {
|
||||||
|
client.post_records(&buf).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(total)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sync_download(
|
||||||
|
store: &mut impl Store,
|
||||||
|
remote_index: &RecordIndex,
|
||||||
|
client: &Client<'_>,
|
||||||
|
op: (HostId, String, RecordId),
|
||||||
|
) -> Result<i64> {
|
||||||
|
// TODO(ellie): implement variable page sizing like on history sync
|
||||||
|
let download_page_size = 1000;
|
||||||
|
|
||||||
|
let mut total = 0;
|
||||||
|
|
||||||
|
// We know that the remote is ahead of us, so let's keep downloading until both
|
||||||
|
// 1) The remote stops returning full pages
|
||||||
|
// 2) The tail equals what we expect
|
||||||
|
//
|
||||||
|
// If (1) occurs without (2), then something is wrong with our index calculation
|
||||||
|
// and we should bail.
|
||||||
|
let remote_tail = remote_index
|
||||||
|
.get(op.0, op.1.clone())
|
||||||
|
.expect("remote index does not contain expected tail during download");
|
||||||
|
let local_tail = store.tail(op.0, op.1.as_str()).await?;
|
||||||
|
//
|
||||||
|
// We expect that the operations diff will represent the desired state
|
||||||
|
// In this case, that contains the remote tail.
|
||||||
|
assert_eq!(remote_tail, op.2);
|
||||||
|
|
||||||
|
println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
|
||||||
|
|
||||||
|
let mut records = client
|
||||||
|
.next_records(
|
||||||
|
op.0,
|
||||||
|
op.1.clone(),
|
||||||
|
local_tail.map(|r| r.id),
|
||||||
|
download_page_size,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
while !records.is_empty() {
|
||||||
|
total += std::cmp::min(download_page_size, records.len() as u64);
|
||||||
|
store.push_batch(records.iter()).await?;
|
||||||
|
|
||||||
|
if records.last().unwrap().id == remote_tail {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
records = client
|
||||||
|
.next_records(
|
||||||
|
op.0,
|
||||||
|
op.1.clone(),
|
||||||
|
records.last().map(|r| r.id),
|
||||||
|
download_page_size,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(total as i64)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn sync_remote(
|
||||||
|
operations: Vec<Operation>,
|
||||||
|
remote_index: &RecordIndex,
|
||||||
|
local_store: &mut impl Store,
|
||||||
|
settings: &Settings,
|
||||||
|
) -> Result<(i64, i64)> {
|
||||||
|
let client = Client::new(&settings.sync_address, &settings.session_token)?;
|
||||||
|
|
||||||
|
let mut uploaded = 0;
|
||||||
|
let mut downloaded = 0;
|
||||||
|
|
||||||
|
// this can totally run in parallel, but lets get it working first
|
||||||
|
for i in operations {
|
||||||
|
match i {
|
||||||
|
Operation::Upload { tail, host, tag } => {
|
||||||
|
uploaded +=
|
||||||
|
sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
|
||||||
|
}
|
||||||
|
Operation::Download { tail, host, tag } => {
|
||||||
|
downloaded +=
|
||||||
|
sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((uploaded, downloaded))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use atuin_common::record::{Diff, EncryptedData, HostId, Record};
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
|
use crate::record::{
|
||||||
|
encryption::PASETO_V4,
|
||||||
|
sqlite_store::SqliteStore,
|
||||||
|
store::Store,
|
||||||
|
sync::{self, Operation},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn test_record() -> Record<EncryptedData> {
|
||||||
|
Record::builder()
|
||||||
|
.host(HostId(atuin_common::utils::uuid_v7()))
|
||||||
|
.version("v1".into())
|
||||||
|
.tag(atuin_common::utils::uuid_v7().simple().to_string())
|
||||||
|
.data(EncryptedData {
|
||||||
|
data: String::new(),
|
||||||
|
content_encryption_key: String::new(),
|
||||||
|
})
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take a list of local records, and a list of remote records.
|
||||||
|
// Return the local database, and a diff of local/remote, ready to build
|
||||||
|
// ops
|
||||||
|
async fn build_test_diff(
|
||||||
|
local_records: Vec<Record<EncryptedData>>,
|
||||||
|
remote_records: Vec<Record<EncryptedData>>,
|
||||||
|
) -> (SqliteStore, Vec<Diff>) {
|
||||||
|
let local_store = SqliteStore::new(":memory:")
|
||||||
|
.await
|
||||||
|
.expect("failed to open in memory sqlite");
|
||||||
|
let remote_store = SqliteStore::new(":memory:")
|
||||||
|
.await
|
||||||
|
.expect("failed to open in memory sqlite"); // "remote"
|
||||||
|
|
||||||
|
for i in local_records {
|
||||||
|
local_store.push(&i).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in remote_records {
|
||||||
|
remote_store.push(&i).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let local_index = local_store.tail_records().await.unwrap();
|
||||||
|
let remote_index = remote_store.tail_records().await.unwrap();
|
||||||
|
|
||||||
|
let diff = local_index.diff(&remote_index);
|
||||||
|
|
||||||
|
(local_store, diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_basic_diff() {
|
||||||
|
// a diff where local is ahead of remote. nothing else.
|
||||||
|
|
||||||
|
let record = test_record();
|
||||||
|
let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await;
|
||||||
|
|
||||||
|
assert_eq!(diff.len(), 1);
|
||||||
|
|
||||||
|
let operations = sync::operations(diff, &store).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(operations.len(), 1);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
operations[0],
|
||||||
|
Operation::Upload {
|
||||||
|
host: record.host,
|
||||||
|
tag: record.tag,
|
||||||
|
tail: record.id
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn build_two_way_diff() {
|
||||||
|
// a diff where local is ahead of remote for one, and remote for
|
||||||
|
// another. One upload, one download
|
||||||
|
|
||||||
|
let shared_record = test_record();
|
||||||
|
|
||||||
|
let remote_ahead = test_record();
|
||||||
|
let local_ahead = shared_record
|
||||||
|
.new_child(vec![1, 2, 3])
|
||||||
|
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||||
|
|
||||||
|
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
|
||||||
|
let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
|
||||||
|
|
||||||
|
let (store, diff) = build_test_diff(local, remote).await;
|
||||||
|
let operations = sync::operations(diff, &store).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(operations.len(), 2);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
operations,
|
||||||
|
vec![
|
||||||
|
Operation::Download {
|
||||||
|
tail: remote_ahead.id,
|
||||||
|
host: remote_ahead.host,
|
||||||
|
tag: remote_ahead.tag,
|
||||||
|
},
|
||||||
|
Operation::Upload {
|
||||||
|
tail: local_ahead.id,
|
||||||
|
host: local_ahead.host,
|
||||||
|
tag: local_ahead.tag,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn build_complex_diff() {
|
||||||
|
// One shared, ahead but known only by remote
|
||||||
|
// One known only by local
|
||||||
|
// One known only by remote
|
||||||
|
|
||||||
|
let shared_record = test_record();
|
||||||
|
|
||||||
|
let remote_known = test_record();
|
||||||
|
let local_known = test_record();
|
||||||
|
|
||||||
|
let second_shared = test_record();
|
||||||
|
let second_shared_remote_ahead = second_shared
|
||||||
|
.new_child(vec![1, 2, 3])
|
||||||
|
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||||
|
|
||||||
|
let local_ahead = shared_record
|
||||||
|
.new_child(vec![1, 2, 3])
|
||||||
|
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||||
|
|
||||||
|
let local = vec![
|
||||||
|
shared_record.clone(),
|
||||||
|
second_shared.clone(),
|
||||||
|
local_known.clone(),
|
||||||
|
local_ahead.clone(),
|
||||||
|
];
|
||||||
|
|
||||||
|
let remote = vec![
|
||||||
|
shared_record.clone(),
|
||||||
|
second_shared.clone(),
|
||||||
|
second_shared_remote_ahead.clone(),
|
||||||
|
remote_known.clone(),
|
||||||
|
]; // remote knows about the already-synced, and one new record in a new store
|
||||||
|
|
||||||
|
let (store, diff) = build_test_diff(local, remote).await;
|
||||||
|
let operations = sync::operations(diff, &store).await.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(operations.len(), 4);
|
||||||
|
|
||||||
|
let mut result_ops = vec![
|
||||||
|
Operation::Download {
|
||||||
|
tail: remote_known.id,
|
||||||
|
host: remote_known.host,
|
||||||
|
tag: remote_known.tag,
|
||||||
|
},
|
||||||
|
Operation::Download {
|
||||||
|
tail: second_shared_remote_ahead.id,
|
||||||
|
host: second_shared.host,
|
||||||
|
tag: second_shared.tag,
|
||||||
|
},
|
||||||
|
Operation::Upload {
|
||||||
|
tail: local_ahead.id,
|
||||||
|
host: local_ahead.host,
|
||||||
|
tag: local_ahead.tag,
|
||||||
|
},
|
||||||
|
Operation::Upload {
|
||||||
|
tail: local_known.id,
|
||||||
|
host: local_known.host,
|
||||||
|
tag: local_known.tag,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
result_ops.sort_by_key(|op| match op {
|
||||||
|
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
|
||||||
|
Operation::Download { tail, host, .. } => ("download", *host, *tail),
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(operations, result_ops);
|
||||||
|
}
|
||||||
|
}
|
@ -1,8 +1,10 @@
|
|||||||
use std::{
|
use std::{
|
||||||
io::prelude::*,
|
io::prelude::*,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
|
str::FromStr,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use atuin_common::record::HostId;
|
||||||
use chrono::{prelude::*, Utc};
|
use chrono::{prelude::*, Utc};
|
||||||
use clap::ValueEnum;
|
use clap::ValueEnum;
|
||||||
use config::{Config, Environment, File as ConfigFile, FileFormat};
|
use config::{Config, Environment, File as ConfigFile, FileFormat};
|
||||||
@ -12,6 +14,7 @@ use parse_duration::parse;
|
|||||||
use regex::RegexSet;
|
use regex::RegexSet;
|
||||||
use semver::Version;
|
use semver::Version;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub const HISTORY_PAGE_SIZE: i64 = 100;
|
pub const HISTORY_PAGE_SIZE: i64 = 100;
|
||||||
pub const LAST_SYNC_FILENAME: &str = "last_sync_time";
|
pub const LAST_SYNC_FILENAME: &str = "last_sync_time";
|
||||||
@ -228,11 +231,13 @@ impl Settings {
|
|||||||
Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME)
|
Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn host_id() -> Option<String> {
|
pub fn host_id() -> Option<HostId> {
|
||||||
let id = Settings::read_from_data_dir(HOST_ID_FILENAME);
|
let id = Settings::read_from_data_dir(HOST_ID_FILENAME);
|
||||||
|
|
||||||
if id.is_some() {
|
if let Some(id) = id {
|
||||||
return id;
|
let parsed =
|
||||||
|
Uuid::from_str(id.as_str()).expect("failed to parse host ID from local directory");
|
||||||
|
return Some(HostId(parsed));
|
||||||
}
|
}
|
||||||
|
|
||||||
let uuid = atuin_common::utils::uuid_v7();
|
let uuid = atuin_common::utils::uuid_v7();
|
||||||
@ -240,7 +245,7 @@ impl Settings {
|
|||||||
Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref())
|
Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref())
|
||||||
.expect("Could not write host ID to data dir");
|
.expect("Could not write host ID to data dir");
|
||||||
|
|
||||||
Some(uuid.as_simple().to_string())
|
Some(HostId(uuid))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn should_sync(&self) -> Result<bool> {
|
pub fn should_sync(&self) -> Result<bool> {
|
||||||
|
@ -18,6 +18,7 @@ uuid = { workspace = true }
|
|||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
typed-builder = { workspace = true }
|
typed-builder = { workspace = true }
|
||||||
eyre = { workspace = true }
|
eyre = { workspace = true }
|
||||||
|
sqlx = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
pretty_assertions = "1.3.0"
|
pretty_assertions = { workspace = true }
|
||||||
|
@ -1,5 +1,57 @@
|
|||||||
#![forbid(unsafe_code)]
|
#![forbid(unsafe_code)]
|
||||||
|
|
||||||
|
/// Defines a new UUID type wrapper
|
||||||
|
macro_rules! new_uuid {
|
||||||
|
($name:ident) => {
|
||||||
|
#[derive(
|
||||||
|
Debug,
|
||||||
|
Copy,
|
||||||
|
Clone,
|
||||||
|
PartialEq,
|
||||||
|
Eq,
|
||||||
|
Hash,
|
||||||
|
PartialOrd,
|
||||||
|
Ord,
|
||||||
|
serde::Serialize,
|
||||||
|
serde::Deserialize,
|
||||||
|
)]
|
||||||
|
#[serde(transparent)]
|
||||||
|
pub struct $name(pub Uuid);
|
||||||
|
|
||||||
|
impl<DB: sqlx::Database> sqlx::Type<DB> for $name
|
||||||
|
where
|
||||||
|
Uuid: sqlx::Type<DB>,
|
||||||
|
{
|
||||||
|
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
|
||||||
|
Uuid::type_info()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name
|
||||||
|
where
|
||||||
|
Uuid: sqlx::Decode<'r, DB>,
|
||||||
|
{
|
||||||
|
fn decode(
|
||||||
|
value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
|
||||||
|
) -> std::result::Result<Self, sqlx::error::BoxDynError> {
|
||||||
|
Uuid::decode(value).map(Self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name
|
||||||
|
where
|
||||||
|
Uuid: sqlx::Encode<'q, DB>,
|
||||||
|
{
|
||||||
|
fn encode_by_ref(
|
||||||
|
&self,
|
||||||
|
buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer,
|
||||||
|
) -> sqlx::encode::IsNull {
|
||||||
|
self.0.encode_by_ref(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
pub mod api;
|
pub mod api;
|
||||||
pub mod record;
|
pub mod record;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
@ -3,35 +3,43 @@ use std::collections::HashMap;
|
|||||||
use eyre::Result;
|
use eyre::Result;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use typed_builder::TypedBuilder;
|
use typed_builder::TypedBuilder;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
pub struct DecryptedData(pub Vec<u8>);
|
pub struct DecryptedData(pub Vec<u8>);
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct EncryptedData {
|
pub struct EncryptedData {
|
||||||
pub data: String,
|
pub data: String,
|
||||||
pub content_encryption_key: String,
|
pub content_encryption_key: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub struct Diff {
|
||||||
|
pub host: HostId,
|
||||||
|
pub tag: String,
|
||||||
|
pub tail: RecordId,
|
||||||
|
}
|
||||||
|
|
||||||
/// A single record stored inside of our local database
|
/// A single record stored inside of our local database
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
|
||||||
pub struct Record<Data> {
|
pub struct Record<Data> {
|
||||||
/// a unique ID
|
/// a unique ID
|
||||||
#[builder(default = crate::utils::uuid_v7().as_simple().to_string())]
|
#[builder(default = RecordId(crate::utils::uuid_v7()))]
|
||||||
pub id: String,
|
pub id: RecordId,
|
||||||
|
|
||||||
/// The unique ID of the host.
|
/// The unique ID of the host.
|
||||||
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
|
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
|
||||||
// as strings. I would rather avoid normalization, so store as UUID binary instead of
|
// as strings. I would rather avoid normalization, so store as UUID binary instead of
|
||||||
// encoding to a string and wasting much more storage.
|
// encoding to a string and wasting much more storage.
|
||||||
pub host: String,
|
pub host: HostId,
|
||||||
|
|
||||||
/// The ID of the parent entry
|
/// The ID of the parent entry
|
||||||
// A store is technically just a double linked list
|
// A store is technically just a double linked list
|
||||||
// We can do some cheating with the timestamps, but should not rely upon them.
|
// We can do some cheating with the timestamps, but should not rely upon them.
|
||||||
// Clocks are tricksy.
|
// Clocks are tricksy.
|
||||||
#[builder(default)]
|
#[builder(default)]
|
||||||
pub parent: Option<String>,
|
pub parent: Option<RecordId>,
|
||||||
|
|
||||||
/// The creation time in nanoseconds since unix epoch
|
/// The creation time in nanoseconds since unix epoch
|
||||||
#[builder(default = chrono::Utc::now().timestamp_nanos() as u64)]
|
#[builder(default = chrono::Utc::now().timestamp_nanos() as u64)]
|
||||||
@ -48,21 +56,25 @@ pub struct Record<Data> {
|
|||||||
pub data: Data,
|
pub data: Data,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
new_uuid!(RecordId);
|
||||||
|
new_uuid!(HostId);
|
||||||
|
|
||||||
/// Extra data from the record that should be encoded in the data
|
/// Extra data from the record that should be encoded in the data
|
||||||
#[derive(Debug, Copy, Clone)]
|
#[derive(Debug, Copy, Clone)]
|
||||||
pub struct AdditionalData<'a> {
|
pub struct AdditionalData<'a> {
|
||||||
pub id: &'a str,
|
pub id: &'a RecordId,
|
||||||
pub version: &'a str,
|
pub version: &'a str,
|
||||||
pub tag: &'a str,
|
pub tag: &'a str,
|
||||||
pub host: &'a str,
|
pub host: &'a HostId,
|
||||||
|
pub parent: Option<&'a RecordId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Data> Record<Data> {
|
impl<Data> Record<Data> {
|
||||||
pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
|
pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
|
||||||
Record::builder()
|
Record::builder()
|
||||||
.host(self.host.clone())
|
.host(self.host)
|
||||||
.version(self.version.clone())
|
.version(self.version.clone())
|
||||||
.parent(Some(self.id.clone()))
|
.parent(Some(self.id))
|
||||||
.tag(self.tag.clone())
|
.tag(self.tag.clone())
|
||||||
.data(DecryptedData(data))
|
.data(DecryptedData(data))
|
||||||
.build()
|
.build()
|
||||||
@ -71,9 +83,10 @@ impl<Data> Record<Data> {
|
|||||||
|
|
||||||
/// An index representing the current state of the record stores
|
/// An index representing the current state of the record stores
|
||||||
/// This can be both remote, or local, and compared in either direction
|
/// This can be both remote, or local, and compared in either direction
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct RecordIndex {
|
pub struct RecordIndex {
|
||||||
// A map of host -> tag -> tail
|
// A map of host -> tag -> tail
|
||||||
pub hosts: HashMap<String, HashMap<String, String>>,
|
pub hosts: HashMap<HostId, HashMap<String, RecordId>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RecordIndex {
|
impl Default for RecordIndex {
|
||||||
@ -82,6 +95,14 @@ impl Default for RecordIndex {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Extend<(HostId, String, RecordId)> for RecordIndex {
|
||||||
|
fn extend<T: IntoIterator<Item = (HostId, String, RecordId)>>(&mut self, iter: T) {
|
||||||
|
for (host, tag, tail_id) in iter {
|
||||||
|
self.set_raw(host, tag, tail_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl RecordIndex {
|
impl RecordIndex {
|
||||||
pub fn new() -> RecordIndex {
|
pub fn new() -> RecordIndex {
|
||||||
RecordIndex {
|
RecordIndex {
|
||||||
@ -91,13 +112,14 @@ impl RecordIndex {
|
|||||||
|
|
||||||
/// Insert a new tail record into the store
|
/// Insert a new tail record into the store
|
||||||
pub fn set(&mut self, tail: Record<DecryptedData>) {
|
pub fn set(&mut self, tail: Record<DecryptedData>) {
|
||||||
self.hosts
|
self.set_raw(tail.host, tail.tag, tail.id)
|
||||||
.entry(tail.host)
|
|
||||||
.or_default()
|
|
||||||
.insert(tail.tag, tail.id);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, host: String, tag: String) -> Option<String> {
|
pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordId) {
|
||||||
|
self.hosts.entry(host).or_default().insert(tag, tail_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, host: HostId, tag: String) -> Option<RecordId> {
|
||||||
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
|
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,21 +130,29 @@ impl RecordIndex {
|
|||||||
/// other machine has a different tail, it will be the differing tail. This is useful to
|
/// other machine has a different tail, it will be the differing tail. This is useful to
|
||||||
/// check if the other index is ahead of us, or behind.
|
/// check if the other index is ahead of us, or behind.
|
||||||
/// If the other index does not have the (host, tag) pair, then the other value will be None.
|
/// If the other index does not have the (host, tag) pair, then the other value will be None.
|
||||||
pub fn diff(&self, other: &Self) -> Vec<(String, String, Option<String>)> {
|
pub fn diff(&self, other: &Self) -> Vec<Diff> {
|
||||||
let mut ret = Vec::new();
|
let mut ret = Vec::new();
|
||||||
|
|
||||||
// First, we check if other has everything that self has
|
// First, we check if other has everything that self has
|
||||||
for (host, tag_map) in self.hosts.iter() {
|
for (host, tag_map) in self.hosts.iter() {
|
||||||
for (tag, tail) in tag_map.iter() {
|
for (tag, tail) in tag_map.iter() {
|
||||||
match other.get(host.clone(), tag.clone()) {
|
match other.get(*host, tag.clone()) {
|
||||||
// The other store is all up to date! No diff.
|
// The other store is all up to date! No diff.
|
||||||
Some(t) if t.eq(tail) => continue,
|
Some(t) if t.eq(tail) => continue,
|
||||||
|
|
||||||
// The other store does exist, but it is either ahead or behind us. A diff regardless
|
// The other store does exist, but it is either ahead or behind us. A diff regardless
|
||||||
Some(t) => ret.push((host.clone(), tag.clone(), Some(t))),
|
Some(t) => ret.push(Diff {
|
||||||
|
host: *host,
|
||||||
|
tag: tag.clone(),
|
||||||
|
tail: t,
|
||||||
|
}),
|
||||||
|
|
||||||
// The other store does not exist :O
|
// The other store does not exist :O
|
||||||
None => ret.push((host.clone(), tag.clone(), None)),
|
None => ret.push(Diff {
|
||||||
|
host: *host,
|
||||||
|
tag: tag.clone(),
|
||||||
|
tail: *tail,
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -133,16 +163,20 @@ impl RecordIndex {
|
|||||||
// account for that!
|
// account for that!
|
||||||
for (host, tag_map) in other.hosts.iter() {
|
for (host, tag_map) in other.hosts.iter() {
|
||||||
for (tag, tail) in tag_map.iter() {
|
for (tag, tail) in tag_map.iter() {
|
||||||
match self.get(host.clone(), tag.clone()) {
|
match self.get(*host, tag.clone()) {
|
||||||
// If we have this host/tag combo, the comparison and diff will have already happened above
|
// If we have this host/tag combo, the comparison and diff will have already happened above
|
||||||
Some(_) => continue,
|
Some(_) => continue,
|
||||||
|
|
||||||
None => ret.push((host.clone(), tag.clone(), Some(tail.clone()))),
|
None => ret.push(Diff {
|
||||||
|
host: *host,
|
||||||
|
tag: tag.clone(),
|
||||||
|
tail: *tail,
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ret.sort();
|
ret.sort_by(|a, b| (a.host, a.tag.clone(), a.tail).cmp(&(b.host, b.tag.clone(), b.tail)));
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -168,6 +202,7 @@ impl Record<DecryptedData> {
|
|||||||
version: &self.version,
|
version: &self.version,
|
||||||
tag: &self.tag,
|
tag: &self.tag,
|
||||||
host: &self.host,
|
host: &self.host,
|
||||||
|
parent: self.parent.as_ref(),
|
||||||
};
|
};
|
||||||
Record {
|
Record {
|
||||||
data: E::encrypt(self.data, ad, key),
|
data: E::encrypt(self.data, ad, key),
|
||||||
@ -188,6 +223,7 @@ impl Record<EncryptedData> {
|
|||||||
version: &self.version,
|
version: &self.version,
|
||||||
tag: &self.tag,
|
tag: &self.tag,
|
||||||
host: &self.host,
|
host: &self.host,
|
||||||
|
parent: self.parent.as_ref(),
|
||||||
};
|
};
|
||||||
Ok(Record {
|
Ok(Record {
|
||||||
data: E::decrypt(self.data, ad, key)?,
|
data: E::decrypt(self.data, ad, key)?,
|
||||||
@ -210,6 +246,7 @@ impl Record<EncryptedData> {
|
|||||||
version: &self.version,
|
version: &self.version,
|
||||||
tag: &self.tag,
|
tag: &self.tag,
|
||||||
host: &self.host,
|
host: &self.host,
|
||||||
|
parent: self.parent.as_ref(),
|
||||||
};
|
};
|
||||||
Ok(Record {
|
Ok(Record {
|
||||||
data: E::re_encrypt(self.data, ad, old_key, new_key)?,
|
data: E::re_encrypt(self.data, ad, old_key, new_key)?,
|
||||||
@ -225,12 +262,14 @@ impl Record<EncryptedData> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{DecryptedData, Record, RecordIndex};
|
use crate::record::HostId;
|
||||||
|
|
||||||
|
use super::{DecryptedData, Diff, Record, RecordIndex};
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
fn test_record() -> Record<DecryptedData> {
|
fn test_record() -> Record<DecryptedData> {
|
||||||
Record::builder()
|
Record::builder()
|
||||||
.host(crate::utils::uuid_v7().simple().to_string())
|
.host(HostId(crate::utils::uuid_v7()))
|
||||||
.version("v1".into())
|
.version("v1".into())
|
||||||
.tag(crate::utils::uuid_v7().simple().to_string())
|
.tag(crate::utils::uuid_v7().simple().to_string())
|
||||||
.data(DecryptedData(vec![0, 1, 2, 3]))
|
.data(DecryptedData(vec![0, 1, 2, 3]))
|
||||||
@ -304,7 +343,14 @@ mod tests {
|
|||||||
let diff = index1.diff(&index2);
|
let diff = index1.diff(&index2);
|
||||||
|
|
||||||
assert_eq!(1, diff.len(), "expected single diff");
|
assert_eq!(1, diff.len(), "expected single diff");
|
||||||
assert_eq!(diff[0], (record2.host, record2.tag, Some(record2.id)));
|
assert_eq!(
|
||||||
|
diff[0],
|
||||||
|
Diff {
|
||||||
|
host: record2.host,
|
||||||
|
tag: record2.tag,
|
||||||
|
tail: record2.id
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -342,12 +388,14 @@ mod tests {
|
|||||||
assert_eq!(4, diff1.len());
|
assert_eq!(4, diff1.len());
|
||||||
assert_eq!(4, diff2.len());
|
assert_eq!(4, diff2.len());
|
||||||
|
|
||||||
|
dbg!(&diff1, &diff2);
|
||||||
|
|
||||||
// both diffs should be ALMOST the same. They will agree on which hosts and tags
|
// both diffs should be ALMOST the same. They will agree on which hosts and tags
|
||||||
// require updating, but the "other" value will not be the same.
|
// require updating, but the "other" value will not be the same.
|
||||||
let smol_diff_1: Vec<(String, String)> =
|
let smol_diff_1: Vec<(HostId, String)> =
|
||||||
diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();
|
diff1.iter().map(|v| (v.host, v.tag.clone())).collect();
|
||||||
let smol_diff_2: Vec<(String, String)> =
|
let smol_diff_2: Vec<(HostId, String)> =
|
||||||
diff1.iter().map(|v| (v.0.clone(), v.1.clone())).collect();
|
diff1.iter().map(|v| (v.host, v.tag.clone())).collect();
|
||||||
|
|
||||||
assert_eq!(smol_diff_1, smol_diff_2);
|
assert_eq!(smol_diff_1, smol_diff_2);
|
||||||
|
|
||||||
|
@ -13,7 +13,10 @@ use self::{
|
|||||||
models::{History, NewHistory, NewSession, NewUser, Session, User},
|
models::{History, NewHistory, NewSession, NewUser, Session, User},
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use atuin_common::utils::get_days_from_month;
|
use atuin_common::{
|
||||||
|
record::{EncryptedData, HostId, Record, RecordId, RecordIndex},
|
||||||
|
utils::get_days_from_month,
|
||||||
|
};
|
||||||
use chrono::{Datelike, TimeZone};
|
use chrono::{Datelike, TimeZone};
|
||||||
use chronoutil::RelativeDuration;
|
use chronoutil::RelativeDuration;
|
||||||
use serde::{de::DeserializeOwned, Serialize};
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
@ -55,6 +58,19 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
|
|||||||
async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
|
async fn delete_history(&self, user: &User, id: String) -> DbResult<()>;
|
||||||
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
|
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
|
||||||
|
|
||||||
|
async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>;
|
||||||
|
async fn next_records(
|
||||||
|
&self,
|
||||||
|
user: &User,
|
||||||
|
host: HostId,
|
||||||
|
tag: String,
|
||||||
|
start: Option<RecordId>,
|
||||||
|
count: u64,
|
||||||
|
) -> DbResult<Vec<Record<EncryptedData>>>;
|
||||||
|
|
||||||
|
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
|
||||||
|
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
|
||||||
|
|
||||||
async fn count_history_range(
|
async fn count_history_range(
|
||||||
&self,
|
&self,
|
||||||
user: &User,
|
user: &User,
|
||||||
|
@ -18,4 +18,5 @@ chrono = { workspace = true }
|
|||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
sqlx = { workspace = true }
|
sqlx = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
|
uuid = { workspace = true }
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
|
5
atuin-server-postgres/build.rs
Normal file
5
atuin-server-postgres/build.rs
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
// generated by `sqlx migrate build-script`
|
||||||
|
fn main() {
|
||||||
|
// trigger recompilation when a new migration is added
|
||||||
|
println!("cargo:rerun-if-changed=migrations");
|
||||||
|
}
|
15
atuin-server-postgres/migrations/20230623070418_records.sql
Normal file
15
atuin-server-postgres/migrations/20230623070418_records.sql
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
-- Add migration script here
|
||||||
|
create table records (
|
||||||
|
id uuid primary key, -- remember to use uuidv7 for happy indices <3
|
||||||
|
client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key
|
||||||
|
host uuid not null, -- a unique identifier for the host
|
||||||
|
parent uuid default null, -- the ID of the parent record, bearing in mind this is a linked list
|
||||||
|
timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
|
||||||
|
version text not null,
|
||||||
|
tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
|
||||||
|
data text not null, -- store the actual history data, encrypted. I don't wanna know!
|
||||||
|
cek text not null,
|
||||||
|
|
||||||
|
user_id bigint not null, -- allow multiple users
|
||||||
|
created_at timestamp not null default current_timestamp
|
||||||
|
);
|
@ -1,14 +1,14 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||||
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
|
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
|
||||||
use atuin_server_database::{Database, DbError, DbResult};
|
use atuin_server_database::{Database, DbError, DbResult};
|
||||||
use futures_util::TryStreamExt;
|
use futures_util::TryStreamExt;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::postgres::PgPoolOptions;
|
use sqlx::postgres::PgPoolOptions;
|
||||||
|
|
||||||
use sqlx::Row;
|
use sqlx::Row;
|
||||||
|
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
use wrappers::{DbHistory, DbSession, DbUser};
|
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
|
||||||
|
|
||||||
mod wrappers;
|
mod wrappers;
|
||||||
|
|
||||||
@ -329,4 +329,102 @@ impl Database for Postgres {
|
|||||||
.map_err(fix_error)
|
.map_err(fix_error)
|
||||||
.map(|DbHistory(h)| h)
|
.map(|DbHistory(h)| h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
|
||||||
|
let mut tx = self.pool.begin().await.map_err(fix_error)?;
|
||||||
|
|
||||||
|
for i in records {
|
||||||
|
let id = atuin_common::utils::uuid_v7();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"insert into records
|
||||||
|
(id, client_id, host, parent, timestamp, version, tag, data, cek, user_id)
|
||||||
|
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||||
|
on conflict do nothing
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.bind(id)
|
||||||
|
.bind(i.id)
|
||||||
|
.bind(i.host)
|
||||||
|
.bind(i.parent)
|
||||||
|
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
|
||||||
|
.bind(&i.version)
|
||||||
|
.bind(&i.tag)
|
||||||
|
.bind(&i.data.data)
|
||||||
|
.bind(&i.data.content_encryption_key)
|
||||||
|
.bind(user.id)
|
||||||
|
.execute(&mut tx)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.commit().await.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn next_records(
|
||||||
|
&self,
|
||||||
|
user: &User,
|
||||||
|
host: HostId,
|
||||||
|
tag: String,
|
||||||
|
start: Option<RecordId>,
|
||||||
|
count: u64,
|
||||||
|
) -> DbResult<Vec<Record<EncryptedData>>> {
|
||||||
|
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
|
||||||
|
let mut ret = Vec::with_capacity(count as usize);
|
||||||
|
let mut parent = start;
|
||||||
|
|
||||||
|
// yeah let's do something better
|
||||||
|
for _ in 0..count {
|
||||||
|
// a very much not ideal query. but it's simple at least?
|
||||||
|
// we are basically using postgres as a kv store here, so... maybe consider using an actual
|
||||||
|
// kv store?
|
||||||
|
let record: Result<DbRecord, DbError> = sqlx::query_as(
|
||||||
|
"select client_id, host, parent, timestamp, version, tag, data, cek from records
|
||||||
|
where user_id = $1
|
||||||
|
and tag = $2
|
||||||
|
and host = $3
|
||||||
|
and parent is not distinct from $4",
|
||||||
|
)
|
||||||
|
.bind(user.id)
|
||||||
|
.bind(tag.clone())
|
||||||
|
.bind(host)
|
||||||
|
.bind(parent)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(fix_error);
|
||||||
|
|
||||||
|
match record {
|
||||||
|
Ok(record) => {
|
||||||
|
let record: Record<EncryptedData> = record.into();
|
||||||
|
ret.push(record.clone());
|
||||||
|
|
||||||
|
parent = Some(record.id);
|
||||||
|
}
|
||||||
|
Err(DbError::NotFound) => {
|
||||||
|
tracing::debug!("hit tail of store: {:?}/{}", host, tag);
|
||||||
|
return Ok(ret);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> {
|
||||||
|
const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0;";
|
||||||
|
|
||||||
|
let res = sqlx::query_as(TAIL_RECORDS_SQL)
|
||||||
|
.bind(user.id)
|
||||||
|
.fetch(&self.pool)
|
||||||
|
.try_collect()
|
||||||
|
.await
|
||||||
|
.map_err(fix_error)?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
use ::sqlx::{FromRow, Result};
|
use ::sqlx::{FromRow, Result};
|
||||||
|
use atuin_common::record::{EncryptedData, Record};
|
||||||
use atuin_server_database::models::{History, Session, User};
|
use atuin_server_database::models::{History, Session, User};
|
||||||
use sqlx::{postgres::PgRow, Row};
|
use sqlx::{postgres::PgRow, Row};
|
||||||
|
|
||||||
pub struct DbUser(pub User);
|
pub struct DbUser(pub User);
|
||||||
pub struct DbSession(pub Session);
|
pub struct DbSession(pub Session);
|
||||||
pub struct DbHistory(pub History);
|
pub struct DbHistory(pub History);
|
||||||
|
pub struct DbRecord(pub Record<EncryptedData>);
|
||||||
|
|
||||||
impl<'a> FromRow<'a, PgRow> for DbUser {
|
impl<'a> FromRow<'a, PgRow> for DbUser {
|
||||||
fn from_row(row: &'a PgRow) -> Result<Self> {
|
fn from_row(row: &'a PgRow) -> Result<Self> {
|
||||||
@ -40,3 +42,30 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
|
||||||
|
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
|
||||||
|
let timestamp: i64 = row.try_get("timestamp")?;
|
||||||
|
|
||||||
|
let data = EncryptedData {
|
||||||
|
data: row.try_get("data")?,
|
||||||
|
content_encryption_key: row.try_get("cek")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self(Record {
|
||||||
|
id: row.try_get("client_id")?,
|
||||||
|
host: row.try_get("host")?,
|
||||||
|
parent: row.try_get("parent")?,
|
||||||
|
timestamp: timestamp as u64,
|
||||||
|
version: row.try_get("version")?,
|
||||||
|
tag: row.try_get("tag")?,
|
||||||
|
data,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DbRecord> for Record<EncryptedData> {
|
||||||
|
fn from(other: DbRecord) -> Record<EncryptedData> {
|
||||||
|
Record { ..other.0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ use atuin_common::api::{ErrorResponse, IndexResponse};
|
|||||||
use axum::{response::IntoResponse, Json};
|
use axum::{response::IntoResponse, Json};
|
||||||
|
|
||||||
pub mod history;
|
pub mod history;
|
||||||
|
pub mod record;
|
||||||
pub mod status;
|
pub mod status;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
|
|
||||||
|
104
atuin-server/src/handlers/record.rs
Normal file
104
atuin-server/src/handlers/record.rs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
use axum::{extract::Query, extract::State, Json};
|
||||||
|
use http::StatusCode;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tracing::{error, instrument};
|
||||||
|
|
||||||
|
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||||
|
use crate::router::{AppState, UserAuth};
|
||||||
|
use atuin_server_database::Database;
|
||||||
|
|
||||||
|
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||||
|
|
||||||
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
|
pub async fn post<DB: Database>(
|
||||||
|
UserAuth(user): UserAuth,
|
||||||
|
state: State<AppState<DB>>,
|
||||||
|
Json(records): Json<Vec<Record<EncryptedData>>>,
|
||||||
|
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||||
|
let State(AppState { database, settings }) = state;
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
count = records.len(),
|
||||||
|
user = user.username,
|
||||||
|
"request to add records"
|
||||||
|
);
|
||||||
|
|
||||||
|
let too_big = records
|
||||||
|
.iter()
|
||||||
|
.any(|r| r.data.data.len() >= settings.max_record_size || settings.max_record_size == 0);
|
||||||
|
|
||||||
|
if too_big {
|
||||||
|
return Err(
|
||||||
|
ErrorResponse::reply("could not add records; record too large")
|
||||||
|
.with_status(StatusCode::BAD_REQUEST),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = database.add_records(&user, &records).await {
|
||||||
|
error!("failed to add record: {}", e);
|
||||||
|
|
||||||
|
return Err(ErrorResponse::reply("failed to add record")
|
||||||
|
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
|
pub async fn index<DB: Database>(
|
||||||
|
UserAuth(user): UserAuth,
|
||||||
|
state: State<AppState<DB>>,
|
||||||
|
) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> {
|
||||||
|
let State(AppState {
|
||||||
|
database,
|
||||||
|
settings: _,
|
||||||
|
}) = state;
|
||||||
|
|
||||||
|
let record_index = match database.tail_records(&user).await {
|
||||||
|
Ok(index) => index,
|
||||||
|
Err(e) => {
|
||||||
|
error!("failed to get record index: {}", e);
|
||||||
|
|
||||||
|
return Err(ErrorResponse::reply("failed to calculate record index")
|
||||||
|
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Json(record_index))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct NextParams {
|
||||||
|
host: HostId,
|
||||||
|
tag: String,
|
||||||
|
start: Option<RecordId>,
|
||||||
|
count: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all, fields(user.id = user.id))]
|
||||||
|
pub async fn next<DB: Database>(
|
||||||
|
params: Query<NextParams>,
|
||||||
|
UserAuth(user): UserAuth,
|
||||||
|
state: State<AppState<DB>>,
|
||||||
|
) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> {
|
||||||
|
let State(AppState {
|
||||||
|
database,
|
||||||
|
settings: _,
|
||||||
|
}) = state;
|
||||||
|
let params = params.0;
|
||||||
|
|
||||||
|
let records = match database
|
||||||
|
.next_records(&user, params.host, params.tag, params.start, params.count)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(records) => records,
|
||||||
|
Err(e) => {
|
||||||
|
error!("failed to get record index: {}", e);
|
||||||
|
|
||||||
|
return Err(ErrorResponse::reply("failed to calculate record index")
|
||||||
|
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Json(records))
|
||||||
|
}
|
@ -71,6 +71,9 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
|
|||||||
.route("/sync/status", get(handlers::status::status))
|
.route("/sync/status", get(handlers::status::status))
|
||||||
.route("/history", post(handlers::history::add))
|
.route("/history", post(handlers::history::add))
|
||||||
.route("/history", delete(handlers::history::delete))
|
.route("/history", delete(handlers::history::delete))
|
||||||
|
.route("/record", post(handlers::record::post))
|
||||||
|
.route("/record", get(handlers::record::index))
|
||||||
|
.route("/record/next", get(handlers::record::next))
|
||||||
.route("/user/:username", get(handlers::user::get))
|
.route("/user/:username", get(handlers::user::get))
|
||||||
.route("/account", delete(handlers::user::delete))
|
.route("/account", delete(handlers::user::delete))
|
||||||
.route("/register", post(handlers::user::register))
|
.route("/register", post(handlers::user::register))
|
||||||
|
@ -12,6 +12,7 @@ pub struct Settings<DbSettings> {
|
|||||||
pub path: String,
|
pub path: String,
|
||||||
pub open_registration: bool,
|
pub open_registration: bool,
|
||||||
pub max_history_length: usize,
|
pub max_history_length: usize,
|
||||||
|
pub max_record_size: usize,
|
||||||
pub page_size: i64,
|
pub page_size: i64,
|
||||||
pub register_webhook_url: Option<String>,
|
pub register_webhook_url: Option<String>,
|
||||||
pub register_webhook_username: String,
|
pub register_webhook_username: String,
|
||||||
@ -39,6 +40,7 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
|
|||||||
.set_default("port", 8888)?
|
.set_default("port", 8888)?
|
||||||
.set_default("open_registration", false)?
|
.set_default("open_registration", false)?
|
||||||
.set_default("max_history_length", 8192)?
|
.set_default("max_history_length", 8192)?
|
||||||
|
.set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky
|
||||||
.set_default("path", "")?
|
.set_default("path", "")?
|
||||||
.set_default("register_webhook_username", "")?
|
.set_default("register_webhook_username", "")?
|
||||||
.set_default("page_size", 1100)?
|
.set_default("page_size", 1100)?
|
||||||
|
@ -69,7 +69,7 @@ impl Cmd {
|
|||||||
Self::Search(search) => search.run(db, &mut settings).await,
|
Self::Search(search) => search.run(db, &mut settings).await,
|
||||||
|
|
||||||
#[cfg(feature = "sync")]
|
#[cfg(feature = "sync")]
|
||||||
Self::Sync(sync) => sync.run(settings, &mut db).await,
|
Self::Sync(sync) => sync.run(settings, &mut db, &mut store).await,
|
||||||
|
|
||||||
#[cfg(feature = "sync")]
|
#[cfg(feature = "sync")]
|
||||||
Self::Account(account) => account.run(settings).await,
|
Self::Account(account) => account.run(settings).await,
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
use eyre::{Result, WrapErr};
|
use eyre::{Result, WrapErr};
|
||||||
|
|
||||||
use atuin_client::{database::Database, settings::Settings};
|
use atuin_client::{
|
||||||
|
database::Database,
|
||||||
|
record::{store::Store, sync},
|
||||||
|
settings::Settings,
|
||||||
|
};
|
||||||
|
|
||||||
mod status;
|
mod status;
|
||||||
|
|
||||||
@ -37,9 +41,14 @@ pub enum Cmd {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Cmd {
|
impl Cmd {
|
||||||
pub async fn run(self, settings: Settings, db: &mut impl Database) -> Result<()> {
|
pub async fn run(
|
||||||
|
self,
|
||||||
|
settings: Settings,
|
||||||
|
db: &mut impl Database,
|
||||||
|
store: &mut (impl Store + Send + Sync),
|
||||||
|
) -> Result<()> {
|
||||||
match self {
|
match self {
|
||||||
Self::Sync { force } => run(&settings, force, db).await,
|
Self::Sync { force } => run(&settings, force, db, store).await,
|
||||||
Self::Login(l) => l.run(&settings).await,
|
Self::Login(l) => l.run(&settings).await,
|
||||||
Self::Logout => account::logout::run(&settings),
|
Self::Logout => account::logout::run(&settings),
|
||||||
Self::Register(r) => r.run(&settings).await,
|
Self::Register(r) => r.run(&settings).await,
|
||||||
@ -62,12 +71,26 @@ impl Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> {
|
async fn run(
|
||||||
|
settings: &Settings,
|
||||||
|
force: bool,
|
||||||
|
db: &mut impl Database,
|
||||||
|
store: &mut (impl Store + Send + Sync),
|
||||||
|
) -> Result<()> {
|
||||||
|
let (diff, remote_index) = sync::diff(settings, store).await?;
|
||||||
|
let operations = sync::operations(diff, store).await?;
|
||||||
|
let (uploaded, downloaded) =
|
||||||
|
sync::sync_remote(operations, &remote_index, store, settings).await?;
|
||||||
|
|
||||||
|
println!("{uploaded}/{downloaded} up/down to record store");
|
||||||
|
|
||||||
atuin_client::sync::sync(settings, force, db).await?;
|
atuin_client::sync::sync(settings, force, db).await?;
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"Sync complete! {} items in database, force: {}",
|
"Sync complete! {} items in history database, force: {}",
|
||||||
db.history_count().await?,
|
db.history_count().await?,
|
||||||
force
|
force
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user