mirror of
https://github.com/atuinsh/atuin.git
synced 2024-11-22 00:03:49 +01:00
feat: rework record sync for improved reliability (#1478)
* feat: rework record sync for improved reliability So, to tell a story 1. We introduced the record sync, intended to be the new algorithm to sync history. 2. On top of this, I added the KV store. This was intended as a simple test of the record sync, and to see if people wanted that sort of functionality 3. History remained syncing via the old means, as while it had issues it worked more-or-less OK. And we are aware of its flaws 4. If KV syncing worked ok, history would be moved across KV syncing ran ok for 6mo or so, so I started to move across history. For several weeks, I ran a local fork of Atuin + the server that synced via records instead. The record store maintained ordering via a linked list, which was a mistake. It performed well in testing, but was really difficult to debug and reason about. So when a few small sync issues occured, they took an extremely long time to debug. This PR is huge, which I regret. It involves replacing the "parent" relationship that records once had (pointing to the previous record) with a simple index (generally referred to as idx). This also means we had to change the recordindex, which referenced "tails". Tails were the last item in the chain. Now that we use an "array" vs linked list, that logic was also replaced. And is much simpler :D Same for the queries that act on this data. ---- This isn't final - we still need to add 1. Proper server/client error handling, which has been lacking for a while 2. The actual history implementation on top This exists in a branch, just without deletions. Won't be much to add that, I just don't want to make this any larger than it already is The _only_ caveat here is that we basically lose data synced via the old record store. This is the KV data from before. It hasn't been deleted or anything, just no longer hooked up. So it's totally possible to write a migration script. I just need to do that. * update .gitignore * use correct endpoint * fix for stores with length of 1 * use create/delete enum for history store * lint, remove unneeded host_id * remove prints * add command to import old history * add enable/disable switch for record sync * add record sync to auto sync * satisfy the almighty clippy * remove file that I did not mean to commit * feedback
This commit is contained in:
parent
604ae40b9d
commit
7bc6ccdd70
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,6 +1,8 @@
|
||||
.DS_Store
|
||||
/target
|
||||
*/target
|
||||
.env
|
||||
.idea/
|
||||
.vscode/
|
||||
result
|
||||
publish.sh
|
||||
|
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -242,6 +242,7 @@ dependencies = [
|
||||
"shellexpand",
|
||||
"sql-builder",
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"time",
|
||||
"tokio",
|
||||
"typed-builder",
|
||||
|
@ -43,6 +43,7 @@ uuid = { version = "1.3", features = ["v4", "v7", "serde"] }
|
||||
whoami = "1.1.2"
|
||||
typed-builder = "0.18.0"
|
||||
pretty_assertions = "1.3.0"
|
||||
thiserror = "1.0"
|
||||
|
||||
[workspace.dependencies.reqwest]
|
||||
version = "0.11"
|
||||
|
@ -48,6 +48,7 @@ rmp = { version = "0.8.11" }
|
||||
typed-builder = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
semver = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
futures = "0.3"
|
||||
crypto_secretbox = "0.1.1"
|
||||
generic-array = { version = "0.14", features = ["serde"] }
|
||||
|
@ -0,0 +1,15 @@
|
||||
-- Add migration script here
|
||||
create table if not exists store (
|
||||
id text primary key, -- globally unique ID
|
||||
|
||||
idx integer, -- incrementing integer ID unique per (host, tag)
|
||||
host text not null, -- references the host row
|
||||
tag text not null,
|
||||
|
||||
timestamp integer not null,
|
||||
version text not null,
|
||||
data blob not null,
|
||||
cek blob not null
|
||||
);
|
||||
|
||||
create unique index record_uniq ON store(host, tag, idx);
|
@ -13,11 +13,11 @@ use atuin_common::{
|
||||
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
|
||||
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
|
||||
},
|
||||
record::RecordIndex,
|
||||
record::RecordStatus,
|
||||
};
|
||||
use atuin_common::{
|
||||
api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION},
|
||||
record::{EncryptedData, HostId, Record, RecordId},
|
||||
record::{EncryptedData, HostId, Record, RecordIdx},
|
||||
};
|
||||
use semver::Version;
|
||||
use time::format_description::well_known::Rfc3339;
|
||||
@ -267,10 +267,18 @@ impl<'a> Client<'a> {
|
||||
}
|
||||
|
||||
pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> {
|
||||
let url = format!("{}/record", self.sync_addr);
|
||||
let url = format!("{}/api/v0/record", self.sync_addr);
|
||||
let url = Url::parse(url.as_str())?;
|
||||
|
||||
self.client.post(url).json(records).send().await?;
|
||||
let resp = self.client.post(url).json(records).send().await?;
|
||||
info!("posted records, got {}", resp.status());
|
||||
|
||||
if !resp.status().is_success() {
|
||||
error!(
|
||||
"failed to post records to server; got: {:?}",
|
||||
resp.text().await
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -279,24 +287,22 @@ impl<'a> Client<'a> {
|
||||
&self,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
start: Option<RecordId>,
|
||||
start: RecordIdx,
|
||||
count: u64,
|
||||
) -> Result<Vec<Record<EncryptedData>>> {
|
||||
let url = format!(
|
||||
"{}/record/next?host={}&tag={}&count={}",
|
||||
self.sync_addr, host.0, tag, count
|
||||
debug!(
|
||||
"fetching record/s from host {}/{}/{}",
|
||||
host.0.to_string(),
|
||||
tag,
|
||||
start
|
||||
);
|
||||
let mut url = Url::parse(url.as_str())?;
|
||||
|
||||
if let Some(start) = start {
|
||||
url.set_query(Some(
|
||||
format!(
|
||||
"host={}&tag={}&count={}&start={}",
|
||||
host.0, tag, count, start.0
|
||||
)
|
||||
.as_str(),
|
||||
));
|
||||
}
|
||||
let url = format!(
|
||||
"{}/api/v0/record/next?host={}&tag={}&count={}&start={}",
|
||||
self.sync_addr, host.0, tag, count, start
|
||||
);
|
||||
|
||||
let url = Url::parse(url.as_str())?;
|
||||
|
||||
let resp = self.client.get(url).send().await?;
|
||||
|
||||
@ -305,8 +311,8 @@ impl<'a> Client<'a> {
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
pub async fn record_index(&self) -> Result<RecordIndex> {
|
||||
let url = format!("{}/record", self.sync_addr);
|
||||
pub async fn record_status(&self) -> Result<RecordStatus> {
|
||||
let url = format!("{}/api/v0/record", self.sync_addr);
|
||||
let url = Url::parse(url.as_str())?;
|
||||
|
||||
let resp = self.client.get(url).send().await?;
|
||||
@ -317,6 +323,8 @@ impl<'a> Client<'a> {
|
||||
|
||||
let index = resp.json().await?;
|
||||
|
||||
debug!("got remote index {:?}", index);
|
||||
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,21 @@
|
||||
use rmp::decode::ValueReadError;
|
||||
use rmp::{decode::Bytes, Marker};
|
||||
use std::env;
|
||||
|
||||
use atuin_common::record::DecryptedData;
|
||||
use atuin_common::utils::uuid_v7;
|
||||
|
||||
use eyre::{bail, eyre, Result};
|
||||
use regex::RegexSet;
|
||||
|
||||
use crate::{secrets::SECRET_PATTERNS, settings::Settings};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
mod builder;
|
||||
pub mod store;
|
||||
|
||||
const HISTORY_VERSION: &str = "v0";
|
||||
const HISTORY_TAG: &str = "history";
|
||||
|
||||
/// Client-side history entry.
|
||||
///
|
||||
@ -81,6 +90,108 @@ impl History {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize(&self) -> Result<DecryptedData> {
|
||||
// This is pretty much the same as what we used for the old history, with one difference -
|
||||
// it uses integers for timestamps rather than a string format.
|
||||
|
||||
use rmp::encode;
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
// write the version
|
||||
encode::write_u16(&mut output, 0)?;
|
||||
// INFO: ensure this is updated when adding new fields
|
||||
encode::write_array_len(&mut output, 9)?;
|
||||
|
||||
encode::write_str(&mut output, &self.id)?;
|
||||
encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?;
|
||||
encode::write_sint(&mut output, self.duration)?;
|
||||
encode::write_sint(&mut output, self.exit)?;
|
||||
encode::write_str(&mut output, &self.command)?;
|
||||
encode::write_str(&mut output, &self.cwd)?;
|
||||
encode::write_str(&mut output, &self.session)?;
|
||||
encode::write_str(&mut output, &self.hostname)?;
|
||||
|
||||
match self.deleted_at {
|
||||
Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?,
|
||||
None => encode::write_nil(&mut output)?,
|
||||
}
|
||||
|
||||
Ok(DecryptedData(output))
|
||||
}
|
||||
|
||||
fn deserialize_v0(bytes: &[u8]) -> Result<History> {
|
||||
use rmp::decode;
|
||||
|
||||
fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
|
||||
eyre!("{err:?}")
|
||||
}
|
||||
|
||||
let mut bytes = Bytes::new(bytes);
|
||||
|
||||
let version = decode::read_u16(&mut bytes).map_err(error_report)?;
|
||||
|
||||
if version != 0 {
|
||||
bail!("expected decoding v0 record, found v{version}");
|
||||
}
|
||||
|
||||
let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
|
||||
|
||||
if nfields != 9 {
|
||||
bail!("cannot decrypt history from a different version of Atuin");
|
||||
}
|
||||
|
||||
let bytes = bytes.remaining_slice();
|
||||
let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
|
||||
|
||||
let mut bytes = Bytes::new(bytes);
|
||||
let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?;
|
||||
let duration = decode::read_int(&mut bytes).map_err(error_report)?;
|
||||
let exit = decode::read_int(&mut bytes).map_err(error_report)?;
|
||||
|
||||
let bytes = bytes.remaining_slice();
|
||||
let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
|
||||
let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
|
||||
let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
|
||||
let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
|
||||
|
||||
// if we have more fields, try and get the deleted_at
|
||||
let mut bytes = Bytes::new(bytes);
|
||||
|
||||
let (deleted_at, bytes) = match decode::read_u64(&mut bytes) {
|
||||
Ok(unix) => (Some(unix), bytes.remaining_slice()),
|
||||
// we accept null here
|
||||
Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()),
|
||||
Err(err) => return Err(error_report(err)),
|
||||
};
|
||||
|
||||
if !bytes.is_empty() {
|
||||
bail!("trailing bytes in encoded history. malformed")
|
||||
}
|
||||
|
||||
Ok(History {
|
||||
id: id.to_owned(),
|
||||
timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?,
|
||||
duration,
|
||||
exit,
|
||||
command: command.to_owned(),
|
||||
cwd: cwd.to_owned(),
|
||||
session: session.to_owned(),
|
||||
hostname: hostname.to_owned(),
|
||||
deleted_at: deleted_at
|
||||
.map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128))
|
||||
.transpose()?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn deserialize(bytes: &[u8], version: &str) -> Result<History> {
|
||||
match version {
|
||||
HISTORY_VERSION => Self::deserialize_v0(bytes),
|
||||
|
||||
_ => bail!("unknown version {version:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for a history entry that is imported from shell history.
|
||||
///
|
||||
/// The only two required fields are `timestamp` and `command`.
|
||||
@ -202,8 +313,9 @@ impl History {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use regex::RegexSet;
|
||||
use time::macros::datetime;
|
||||
|
||||
use crate::settings::Settings;
|
||||
use crate::{history::HISTORY_VERSION, settings::Settings};
|
||||
|
||||
use super::History;
|
||||
|
||||
@ -274,4 +386,100 @@ mod tests {
|
||||
|
||||
assert!(stripe_key.should_save(&settings));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_deserialize() {
|
||||
let bytes = [
|
||||
205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55,
|
||||
53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99,
|
||||
98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116,
|
||||
97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100,
|
||||
46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115,
|
||||
47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97,
|
||||
51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49,
|
||||
102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58,
|
||||
99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192,
|
||||
];
|
||||
|
||||
let history = History {
|
||||
id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned(),
|
||||
timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00),
|
||||
duration: 49206000,
|
||||
exit: 0,
|
||||
command: "git status".to_owned(),
|
||||
cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(),
|
||||
session: "b97d9a306f274473a203d2eba41f9457".to_owned(),
|
||||
hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(),
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
let serialized = history.serialize().expect("failed to serialize history");
|
||||
assert_eq!(serialized.0, bytes);
|
||||
|
||||
let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION)
|
||||
.expect("failed to deserialize history");
|
||||
assert_eq!(history, deserialized);
|
||||
|
||||
// test the snapshot too
|
||||
let deserialized =
|
||||
History::deserialize(&bytes, HISTORY_VERSION).expect("failed to deserialize history");
|
||||
assert_eq!(history, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_deserialize_deleted() {
|
||||
let history = History {
|
||||
id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned(),
|
||||
timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00),
|
||||
duration: 49206000,
|
||||
exit: 0,
|
||||
command: "git status".to_owned(),
|
||||
cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(),
|
||||
session: "b97d9a306f274473a203d2eba41f9457".to_owned(),
|
||||
hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(),
|
||||
deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)),
|
||||
};
|
||||
|
||||
let serialized = history.serialize().expect("failed to serialize history");
|
||||
|
||||
let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION)
|
||||
.expect("failed to deserialize history");
|
||||
|
||||
assert_eq!(history, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_deserialize_version() {
|
||||
// v0
|
||||
let bytes_v0 = [
|
||||
205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55,
|
||||
53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99,
|
||||
98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116,
|
||||
97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100,
|
||||
46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115,
|
||||
47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97,
|
||||
51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49,
|
||||
102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58,
|
||||
99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192,
|
||||
];
|
||||
|
||||
// some other version
|
||||
let bytes_v1 = [
|
||||
205, 1, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55,
|
||||
53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99,
|
||||
98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116,
|
||||
97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100,
|
||||
46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115,
|
||||
47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97,
|
||||
51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49,
|
||||
102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58,
|
||||
99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192,
|
||||
];
|
||||
|
||||
let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION);
|
||||
assert!(deserialized.is_ok());
|
||||
|
||||
let deserialized = History::deserialize(&bytes_v1, HISTORY_VERSION);
|
||||
assert!(deserialized.is_err());
|
||||
}
|
||||
}
|
||||
|
219
atuin-client/src/history/store.rs
Normal file
219
atuin-client/src/history/store.rs
Normal file
@ -0,0 +1,219 @@
|
||||
use eyre::{bail, eyre, Result};
|
||||
use rmp::decode::Bytes;
|
||||
|
||||
use crate::record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store};
|
||||
use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordIdx};
|
||||
|
||||
use super::{History, HISTORY_TAG, HISTORY_VERSION};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HistoryStore {
|
||||
pub store: SqliteStore,
|
||||
pub host_id: HostId,
|
||||
pub encryption_key: [u8; 32],
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||
pub enum HistoryRecord {
|
||||
Create(History), // Create a history record
|
||||
Delete(String), // Delete a history record, identified by ID
|
||||
}
|
||||
|
||||
impl HistoryRecord {
|
||||
/// Serialize a history record, returning DecryptedData
|
||||
/// The record will be of a certain type
|
||||
/// We map those like so:
|
||||
///
|
||||
/// HistoryRecord::Create -> 0
|
||||
/// HistoryRecord::Delete-> 1
|
||||
///
|
||||
/// This numeric identifier is then written as the first byte to the buffer. For history, we
|
||||
/// append the serialized history right afterwards, to avoid having to handle serialization
|
||||
/// twice.
|
||||
///
|
||||
/// Deletion simply refers to the history by ID
|
||||
pub fn serialize(&self) -> Result<DecryptedData> {
|
||||
// probably don't actually need to use rmp here, but if we ever need to extend it, it's a
|
||||
// nice wrapper around raw byte stuff
|
||||
use rmp::encode;
|
||||
|
||||
let mut output = vec![];
|
||||
|
||||
match self {
|
||||
HistoryRecord::Create(history) => {
|
||||
// 0 -> a history create
|
||||
encode::write_u8(&mut output, 0)?;
|
||||
|
||||
let bytes = history.serialize()?;
|
||||
|
||||
encode::write_bin(&mut output, &bytes.0)?;
|
||||
}
|
||||
HistoryRecord::Delete(id) => {
|
||||
// 1 -> a history delete
|
||||
encode::write_u8(&mut output, 1)?;
|
||||
encode::write_str(&mut output, id)?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(DecryptedData(output))
|
||||
}
|
||||
|
||||
pub fn deserialize(bytes: &[u8], version: &str) -> Result<Self> {
|
||||
use rmp::decode;
|
||||
|
||||
fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
|
||||
eyre!("{err:?}")
|
||||
}
|
||||
|
||||
let mut bytes = Bytes::new(bytes);
|
||||
|
||||
let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
|
||||
|
||||
match record_type {
|
||||
// 0 -> HistoryRecord::Create
|
||||
0 => {
|
||||
// not super useful to us atm, but perhaps in the future
|
||||
// written by write_bin above
|
||||
let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?;
|
||||
|
||||
let record = History::deserialize(bytes.remaining_slice(), version)?;
|
||||
|
||||
Ok(HistoryRecord::Create(record))
|
||||
}
|
||||
|
||||
// 1 -> HistoryRecord::Delete
|
||||
1 => {
|
||||
let bytes = bytes.remaining_slice();
|
||||
let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
|
||||
|
||||
if !bytes.is_empty() {
|
||||
bail!(
|
||||
"trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(HistoryRecord::Delete(id.to_string()))
|
||||
}
|
||||
|
||||
n => {
|
||||
bail!("unknown HistoryRecord type {n}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HistoryStore {
|
||||
pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self {
|
||||
HistoryStore {
|
||||
store,
|
||||
host_id,
|
||||
encryption_key,
|
||||
}
|
||||
}
|
||||
|
||||
async fn push_record(&self, record: HistoryRecord) -> Result<RecordIdx> {
|
||||
let bytes = record.serialize()?;
|
||||
let idx = self
|
||||
.store
|
||||
.last(self.host_id, HISTORY_TAG)
|
||||
.await?
|
||||
.map_or(0, |p| p.idx + 1);
|
||||
|
||||
let record = Record::builder()
|
||||
.host(Host::new(self.host_id))
|
||||
.version(HISTORY_VERSION.to_string())
|
||||
.tag(HISTORY_TAG.to_string())
|
||||
.idx(idx)
|
||||
.data(bytes)
|
||||
.build();
|
||||
|
||||
self.store
|
||||
.push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
|
||||
.await?;
|
||||
|
||||
Ok(idx)
|
||||
}
|
||||
|
||||
pub async fn delete(&self, id: String) -> Result<RecordIdx> {
|
||||
let record = HistoryRecord::Delete(id);
|
||||
|
||||
self.push_record(record).await
|
||||
}
|
||||
|
||||
pub async fn push(&self, history: History) -> Result<RecordIdx> {
|
||||
// TODO(ellie): move the history store to its own file
|
||||
// it's tiny rn so fine as is
|
||||
let record = HistoryRecord::Create(history);
|
||||
|
||||
self.push_record(record).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use time::macros::datetime;
|
||||
|
||||
use crate::history::{store::HistoryRecord, HISTORY_VERSION};
|
||||
|
||||
use super::History;
|
||||
|
||||
#[test]
|
||||
fn test_serialize_deserialize_create() {
|
||||
let bytes = [
|
||||
204, 0, 196, 141, 205, 0, 0, 153, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49,
|
||||
55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99,
|
||||
56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85,
|
||||
115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116,
|
||||
104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117,
|
||||
105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55,
|
||||
56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112,
|
||||
58, 101, 108, 108, 105, 101, 192,
|
||||
];
|
||||
|
||||
let history = History {
|
||||
id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned(),
|
||||
timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00),
|
||||
duration: 100,
|
||||
exit: 0,
|
||||
command: "ls".to_owned(),
|
||||
cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(),
|
||||
session: "018cd4fead897597852527a31c998059".to_owned(),
|
||||
hostname: "boop:ellie".to_owned(),
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
let record = HistoryRecord::Create(history);
|
||||
|
||||
let serialized = record.serialize().expect("failed to serialize history");
|
||||
assert_eq!(serialized.0, bytes);
|
||||
|
||||
let deserialized = HistoryRecord::deserialize(&serialized.0, HISTORY_VERSION)
|
||||
.expect("failed to deserialize HistoryRecord");
|
||||
assert_eq!(deserialized, record);
|
||||
|
||||
// check the snapshot too
|
||||
let deserialized = HistoryRecord::deserialize(&bytes, HISTORY_VERSION)
|
||||
.expect("failed to deserialize HistoryRecord");
|
||||
assert_eq!(deserialized, record);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialize_deserialize_delete() {
|
||||
let bytes = [
|
||||
204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50,
|
||||
97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49,
|
||||
];
|
||||
let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string());
|
||||
|
||||
let serialized = record.serialize().expect("failed to serialize history");
|
||||
assert_eq!(serialized.0, bytes);
|
||||
|
||||
let deserialized = HistoryRecord::deserialize(&serialized.0, HISTORY_VERSION)
|
||||
.expect("failed to deserialize HistoryRecord");
|
||||
assert_eq!(deserialized, record);
|
||||
|
||||
let deserialized = HistoryRecord::deserialize(&bytes, HISTORY_VERSION)
|
||||
.expect("failed to deserialize HistoryRecord");
|
||||
assert_eq!(deserialized, record);
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use atuin_common::record::{DecryptedData, HostId};
|
||||
use atuin_common::record::{DecryptedData, Host, HostId};
|
||||
use eyre::{bail, ensure, eyre, Result};
|
||||
use serde::Deserialize;
|
||||
|
||||
@ -89,7 +89,7 @@ impl KvStore {
|
||||
|
||||
pub async fn set(
|
||||
&self,
|
||||
store: &mut (impl Store + Send + Sync),
|
||||
store: &(impl Store + Send + Sync),
|
||||
encryption_key: &[u8; 32],
|
||||
host_id: HostId,
|
||||
namespace: &str,
|
||||
@ -111,13 +111,16 @@ impl KvStore {
|
||||
|
||||
let bytes = record.serialize()?;
|
||||
|
||||
let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id);
|
||||
let idx = store
|
||||
.last(host_id, KV_TAG)
|
||||
.await?
|
||||
.map_or(0, |entry| entry.idx + 1);
|
||||
|
||||
let record = atuin_common::record::Record::builder()
|
||||
.host(host_id)
|
||||
.host(Host::new(host_id))
|
||||
.version(KV_VERSION.to_string())
|
||||
.tag(KV_TAG.to_string())
|
||||
.parent(parent)
|
||||
.idx(idx)
|
||||
.data(bytes)
|
||||
.build();
|
||||
|
||||
@ -137,43 +140,18 @@ impl KvStore {
|
||||
namespace: &str,
|
||||
key: &str,
|
||||
) -> Result<Option<KvRecord>> {
|
||||
// Currently, this is O(n). When we have an actual KV store, it can be better
|
||||
// Just a poc for now!
|
||||
// TODO: don't rebuild every time...
|
||||
let map = self.build_kv(store, encryption_key).await?;
|
||||
|
||||
// iterate records to find the value we want
|
||||
// start at the end, so we get the most recent version
|
||||
let tails = store.tag_tails(KV_TAG).await?;
|
||||
let res = map.get(namespace);
|
||||
|
||||
if tails.is_empty() {
|
||||
return Ok(None);
|
||||
if let Some(ns) = res {
|
||||
let value = ns.get(key);
|
||||
|
||||
Ok(value.cloned())
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// first, decide on a record.
|
||||
// try getting the newest first
|
||||
// we always need a way of deciding the "winner" of a write
|
||||
// TODO(ellie): something better than last-write-wins, what if two write at the same time?
|
||||
let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
|
||||
|
||||
loop {
|
||||
let decrypted = match record.version.as_str() {
|
||||
KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
|
||||
version => bail!("unknown version {version:?}"),
|
||||
};
|
||||
|
||||
let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
|
||||
if kv.key == key && kv.namespace == namespace {
|
||||
return Ok(Some(kv));
|
||||
}
|
||||
|
||||
if let Some(parent) = decrypted.parent {
|
||||
record = store.get(parent).await?;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// if we get here, then... we didn't find the record with that key :(
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// Build a kv map out of the linked list kv store
|
||||
@ -184,32 +162,30 @@ impl KvStore {
|
||||
&self,
|
||||
store: &impl Store,
|
||||
encryption_key: &[u8; 32],
|
||||
) -> Result<BTreeMap<String, BTreeMap<String, String>>> {
|
||||
) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> {
|
||||
let mut map = BTreeMap::new();
|
||||
let tails = store.tag_tails(KV_TAG).await?;
|
||||
|
||||
if tails.is_empty() {
|
||||
return Ok(map);
|
||||
}
|
||||
// TODO: maybe don't load the entire tag into memory to build the kv
|
||||
// we can be smart about it and only load values since the last build
|
||||
// or, iterate/paginate
|
||||
let tagged = store.all_tagged(KV_TAG).await?;
|
||||
|
||||
let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
|
||||
|
||||
loop {
|
||||
// iterate through all tags and play each KV record at a time
|
||||
// this is "last write wins"
|
||||
// probably good enough for now, but revisit in future
|
||||
for record in tagged {
|
||||
let decrypted = match record.version.as_str() {
|
||||
KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
|
||||
version => bail!("unknown version {version:?}"),
|
||||
};
|
||||
|
||||
let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
|
||||
let kv = KvRecord::deserialize(&decrypted.data, KV_VERSION)?;
|
||||
|
||||
let ns = map.entry(kv.namespace).or_insert_with(BTreeMap::new);
|
||||
ns.entry(kv.key).or_insert_with(|| kv.value);
|
||||
let ns = map
|
||||
.entry(kv.namespace.clone())
|
||||
.or_insert_with(BTreeMap::new);
|
||||
|
||||
if let Some(parent) = decrypted.parent {
|
||||
record = store.get(parent).await?;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
ns.insert(kv.key.clone(), kv);
|
||||
}
|
||||
|
||||
Ok(map)
|
||||
@ -261,19 +237,27 @@ mod tests {
|
||||
let map = kv.build_kv(&store, &key).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
map.get("test-kv")
|
||||
*map.get("test-kv")
|
||||
.expect("map namespace not set")
|
||||
.get("foo")
|
||||
.expect("map key not set"),
|
||||
"bar"
|
||||
KvRecord {
|
||||
namespace: String::from("test-kv"),
|
||||
key: String::from("foo"),
|
||||
value: String::from("bar")
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
map.get("test-kv")
|
||||
*map.get("test-kv")
|
||||
.expect("map namespace not set")
|
||||
.get("1")
|
||||
.expect("map key not set"),
|
||||
"2"
|
||||
KvRecord {
|
||||
namespace: String::from("test-kv"),
|
||||
key: String::from("1"),
|
||||
value: String::from("2")
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use atuin_common::record::{
|
||||
AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
|
||||
AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx,
|
||||
};
|
||||
use base64::{engine::general_purpose, Engine};
|
||||
use eyre::{ensure, Context, Result};
|
||||
@ -170,10 +170,10 @@ struct AtuinFooter {
|
||||
#[derive(Debug, Copy, Clone, Serialize)]
|
||||
struct Assertions<'a> {
|
||||
id: &'a RecordId,
|
||||
idx: &'a RecordIdx,
|
||||
version: &'a str,
|
||||
tag: &'a str,
|
||||
host: &'a HostId,
|
||||
parent: Option<&'a RecordId>,
|
||||
}
|
||||
|
||||
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
|
||||
@ -183,7 +183,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
|
||||
version: ad.version,
|
||||
tag: ad.tag,
|
||||
host: ad.host,
|
||||
parent: ad.parent,
|
||||
idx: ad.idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -196,7 +196,10 @@ impl Assertions<'_> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use atuin_common::{record::Record, utils::uuid_v7};
|
||||
use atuin_common::{
|
||||
record::{Host, Record},
|
||||
utils::uuid_v7,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
@ -209,7 +212,7 @@ mod tests {
|
||||
version: "v0",
|
||||
tag: "kv",
|
||||
host: &HostId(uuid_v7()),
|
||||
parent: None,
|
||||
idx: &0,
|
||||
};
|
||||
|
||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||
@ -228,7 +231,7 @@ mod tests {
|
||||
version: "v0",
|
||||
tag: "kv",
|
||||
host: &HostId(uuid_v7()),
|
||||
parent: None,
|
||||
idx: &0,
|
||||
};
|
||||
|
||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||
@ -252,7 +255,7 @@ mod tests {
|
||||
version: "v0",
|
||||
tag: "kv",
|
||||
host: &HostId(uuid_v7()),
|
||||
parent: None,
|
||||
idx: &0,
|
||||
};
|
||||
|
||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||
@ -270,7 +273,7 @@ mod tests {
|
||||
version: "v0",
|
||||
tag: "kv",
|
||||
host: &HostId(uuid_v7()),
|
||||
parent: None,
|
||||
idx: &0,
|
||||
};
|
||||
|
||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||
@ -294,7 +297,7 @@ mod tests {
|
||||
version: "v0",
|
||||
tag: "kv",
|
||||
host: &HostId(uuid_v7()),
|
||||
parent: None,
|
||||
idx: &0,
|
||||
};
|
||||
|
||||
let data = DecryptedData(vec![1, 2, 3, 4]);
|
||||
@ -323,9 +326,10 @@ mod tests {
|
||||
.id(RecordId(uuid_v7()))
|
||||
.version("v0".to_owned())
|
||||
.tag("kv".to_owned())
|
||||
.host(HostId(uuid_v7()))
|
||||
.host(Host::new(HostId(uuid_v7())))
|
||||
.timestamp(1687244806000000)
|
||||
.data(DecryptedData(vec![1, 2, 3, 4]))
|
||||
.idx(0)
|
||||
.build();
|
||||
|
||||
let encrypted = record.encrypt::<PASETO_V4>(&key);
|
||||
@ -345,15 +349,16 @@ mod tests {
|
||||
.id(RecordId(uuid_v7()))
|
||||
.version("v0".to_owned())
|
||||
.tag("kv".to_owned())
|
||||
.host(HostId(uuid_v7()))
|
||||
.host(Host::new(HostId(uuid_v7())))
|
||||
.timestamp(1687244806000000)
|
||||
.data(DecryptedData(vec![1, 2, 3, 4]))
|
||||
.idx(0)
|
||||
.build();
|
||||
|
||||
let encrypted = record.encrypt::<PASETO_V4>(&key);
|
||||
|
||||
let mut enc1 = encrypted.clone();
|
||||
enc1.host = HostId(uuid_v7());
|
||||
enc1.host = Host::new(HostId(uuid_v7()));
|
||||
let _ = enc1
|
||||
.decrypt::<PASETO_V4>(&key)
|
||||
.expect_err("tampering with the host should result in auth failure");
|
||||
|
@ -8,17 +8,20 @@ use std::str::FromStr;
|
||||
use async_trait::async_trait;
|
||||
use eyre::{eyre, Result};
|
||||
use fs_err as fs;
|
||||
use futures::TryStreamExt;
|
||||
|
||||
use sqlx::{
|
||||
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
|
||||
Row,
|
||||
};
|
||||
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||
use atuin_common::record::{
|
||||
EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::store::Store;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SqliteStore {
|
||||
pool: SqlitePool,
|
||||
}
|
||||
@ -38,6 +41,7 @@ impl SqliteStore {
|
||||
|
||||
let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())?
|
||||
.journal_mode(SqliteJournalMode::Wal)
|
||||
.foreign_keys(true)
|
||||
.create_if_missing(true);
|
||||
|
||||
let pool = SqlitePoolOptions::new().connect_with(opts).await?;
|
||||
@ -61,14 +65,14 @@ impl SqliteStore {
|
||||
) -> Result<()> {
|
||||
// In sqlite, we are "limited" to i64. But that is still fine, until 2262.
|
||||
sqlx::query(
|
||||
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
|
||||
"insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
|
||||
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
||||
)
|
||||
.bind(r.id.0.as_simple().to_string())
|
||||
.bind(r.host.0.as_simple().to_string())
|
||||
.bind(r.id.0.as_hyphenated().to_string())
|
||||
.bind(r.idx as i64)
|
||||
.bind(r.host.id.0.as_hyphenated().to_string())
|
||||
.bind(r.tag.as_str())
|
||||
.bind(r.timestamp as i64)
|
||||
.bind(r.parent.map(|p| p.0.as_simple().to_string()))
|
||||
.bind(r.version.as_str())
|
||||
.bind(r.data.data.as_str())
|
||||
.bind(r.data.content_encryption_key.as_str())
|
||||
@ -79,20 +83,17 @@ impl SqliteStore {
|
||||
}
|
||||
|
||||
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
|
||||
let idx: i64 = row.get("idx");
|
||||
let timestamp: i64 = row.get("timestamp");
|
||||
|
||||
// tbh at this point things are pretty fucked so just panic
|
||||
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
|
||||
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
|
||||
let parent: Option<&str> = row.get("parent");
|
||||
|
||||
let parent = parent
|
||||
.map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
|
||||
|
||||
Record {
|
||||
id: RecordId(id),
|
||||
host: HostId(host),
|
||||
parent: parent.map(RecordId),
|
||||
idx: idx as u64,
|
||||
host: Host::new(HostId(host)),
|
||||
timestamp: timestamp as u64,
|
||||
tag: row.get("tag"),
|
||||
version: row.get("version"),
|
||||
@ -122,8 +123,8 @@ impl Store for SqliteStore {
|
||||
}
|
||||
|
||||
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
|
||||
let res = sqlx::query("select * from records where id = ?1")
|
||||
.bind(id.0.as_simple().to_string())
|
||||
let res = sqlx::query("select * from store where store.id = ?1")
|
||||
.bind(id.0.as_hyphenated().to_string())
|
||||
.map(Self::query_row)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
@ -131,20 +132,66 @@ impl Store for SqliteStore {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
|
||||
let res: (i64,) =
|
||||
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
|
||||
.bind(host.0.as_simple().to_string())
|
||||
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
||||
let res =
|
||||
sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
|
||||
.bind(host.0.as_hyphenated().to_string())
|
||||
.bind(tag)
|
||||
.map(Self::query_row)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
.await;
|
||||
|
||||
Ok(res.0 as u64)
|
||||
match res {
|
||||
Err(sqlx::Error::RowNotFound) => Ok(None),
|
||||
Err(e) => Err(eyre!("an error occured: {}", e)),
|
||||
Ok(record) => Ok(Some(record)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
|
||||
let res = sqlx::query("select * from records where parent = ?1")
|
||||
.bind(record.id.0.as_simple().to_string())
|
||||
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
||||
self.idx(host, tag, 0).await
|
||||
}
|
||||
|
||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
|
||||
let last = self.last(host, tag).await?;
|
||||
|
||||
if let Some(last) = last {
|
||||
return Ok(last.idx + 1);
|
||||
}
|
||||
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
async fn next(
|
||||
&self,
|
||||
host: HostId,
|
||||
tag: &str,
|
||||
idx: RecordIdx,
|
||||
limit: u64,
|
||||
) -> Result<Vec<Record<EncryptedData>>> {
|
||||
let res =
|
||||
sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4")
|
||||
.bind(idx as i64)
|
||||
.bind(host.0.as_hyphenated().to_string())
|
||||
.bind(tag)
|
||||
.bind(limit as i64)
|
||||
.map(Self::query_row)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn idx(
|
||||
&self,
|
||||
host: HostId,
|
||||
tag: &str,
|
||||
idx: RecordIdx,
|
||||
) -> Result<Option<Record<EncryptedData>>> {
|
||||
let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
|
||||
.bind(idx as i64)
|
||||
.bind(host.0.as_hyphenated().to_string())
|
||||
.bind(tag)
|
||||
.map(Self::query_row)
|
||||
.fetch_one(&self.pool)
|
||||
.await;
|
||||
@ -156,58 +203,36 @@ impl Store for SqliteStore {
|
||||
}
|
||||
}
|
||||
|
||||
async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
||||
let res = sqlx::query(
|
||||
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
|
||||
)
|
||||
.bind(host.0.as_simple().to_string())
|
||||
.bind(tag)
|
||||
.map(Self::query_row)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
async fn status(&self) -> Result<RecordStatus> {
|
||||
let mut status = RecordStatus::new();
|
||||
|
||||
Ok(res)
|
||||
let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
|
||||
sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
|
||||
.fetch_all(&self.pool)
|
||||
.await;
|
||||
|
||||
let res = match res {
|
||||
Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
|
||||
Ok(v) => v,
|
||||
};
|
||||
|
||||
for i in res {
|
||||
let host = HostId(
|
||||
Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
|
||||
);
|
||||
|
||||
status.set_raw(host, i.1, i.2 as u64);
|
||||
}
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
|
||||
let res = sqlx::query(
|
||||
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
|
||||
)
|
||||
.bind(tag)
|
||||
.bind(host.0.as_simple().to_string())
|
||||
.map(Self::query_row)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
|
||||
let res = sqlx::query(
|
||||
"select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
|
||||
)
|
||||
.bind(tag)
|
||||
.map(Self::query_row)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn tail_records(&self) -> Result<RecordIndex> {
|
||||
let res = sqlx::query(
|
||||
"select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
|
||||
)
|
||||
.map(|row: SqliteRow| {
|
||||
let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
|
||||
let tag: String= row.get("tag");
|
||||
let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
|
||||
|
||||
(HostId(host), tag, RecordId(id))
|
||||
})
|
||||
.fetch(&self.pool)
|
||||
.try_collect()
|
||||
.await?;
|
||||
async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
|
||||
let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc")
|
||||
.bind(tag)
|
||||
.map(Self::query_row)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
@ -215,7 +240,7 @@ impl Store for SqliteStore {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use atuin_common::record::{EncryptedData, HostId, Record};
|
||||
use atuin_common::record::{EncryptedData, Host, HostId, Record};
|
||||
|
||||
use crate::record::{encryption::PASETO_V4, store::Store};
|
||||
|
||||
@ -223,13 +248,14 @@ mod tests {
|
||||
|
||||
fn test_record() -> Record<EncryptedData> {
|
||||
Record::builder()
|
||||
.host(HostId(atuin_common::utils::uuid_v7()))
|
||||
.host(Host::new(HostId(atuin_common::utils::uuid_v7())))
|
||||
.version("v1".into())
|
||||
.tag(atuin_common::utils::uuid_v7().simple().to_string())
|
||||
.data(EncryptedData {
|
||||
data: "1234".into(),
|
||||
content_encryption_key: "1234".into(),
|
||||
})
|
||||
.idx(0)
|
||||
.build()
|
||||
}
|
||||
|
||||
@ -263,6 +289,42 @@ mod tests {
|
||||
assert_eq!(record, new_record, "records are not equal");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn last() {
|
||||
let db = SqliteStore::new(":memory:").await.unwrap();
|
||||
let record = test_record();
|
||||
db.push(&record).await.unwrap();
|
||||
|
||||
let last = db
|
||||
.last(record.host.id, record.tag.as_str())
|
||||
.await
|
||||
.expect("failed to get store len");
|
||||
|
||||
assert_eq!(
|
||||
last.unwrap().id,
|
||||
record.id,
|
||||
"expected to get back the same record that was inserted"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn first() {
|
||||
let db = SqliteStore::new(":memory:").await.unwrap();
|
||||
let record = test_record();
|
||||
db.push(&record).await.unwrap();
|
||||
|
||||
let first = db
|
||||
.first(record.host.id, record.tag.as_str())
|
||||
.await
|
||||
.expect("failed to get store len");
|
||||
|
||||
assert_eq!(
|
||||
first.unwrap().id,
|
||||
record.id,
|
||||
"expected to get back the same record that was inserted"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn len() {
|
||||
let db = SqliteStore::new(":memory:").await.unwrap();
|
||||
@ -270,7 +332,7 @@ mod tests {
|
||||
db.push(&record).await.unwrap();
|
||||
|
||||
let len = db
|
||||
.len(record.host, record.tag.as_str())
|
||||
.len(record.host.id, record.tag.as_str())
|
||||
.await
|
||||
.expect("failed to get store len");
|
||||
|
||||
@ -290,8 +352,8 @@ mod tests {
|
||||
db.push(&first).await.unwrap();
|
||||
db.push(&second).await.unwrap();
|
||||
|
||||
let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
|
||||
let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
|
||||
let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
|
||||
let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();
|
||||
|
||||
assert_eq!(first_len, 1, "expected length of 1 after insert");
|
||||
assert_eq!(second_len, 1, "expected length of 1 after insert");
|
||||
@ -305,14 +367,12 @@ mod tests {
|
||||
db.push(&tail).await.expect("failed to push record");
|
||||
|
||||
for _ in 1..100 {
|
||||
tail = tail
|
||||
.new_child(vec![1, 2, 3, 4])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
|
||||
db.push(&tail).await.unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
|
||||
db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
|
||||
100,
|
||||
"failed to insert 100 records"
|
||||
);
|
||||
@ -328,50 +388,16 @@ mod tests {
|
||||
records.push(tail.clone());
|
||||
|
||||
for _ in 1..10000 {
|
||||
tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
|
||||
tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
|
||||
records.push(tail.clone());
|
||||
}
|
||||
|
||||
db.push_batch(records.iter()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
|
||||
db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
|
||||
10000,
|
||||
"failed to insert 10k records"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_chain() {
|
||||
let db = SqliteStore::new(":memory:").await.unwrap();
|
||||
|
||||
let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(1000);
|
||||
|
||||
let mut tail = test_record();
|
||||
records.push(tail.clone());
|
||||
|
||||
for _ in 1..1000 {
|
||||
tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
|
||||
records.push(tail.clone());
|
||||
}
|
||||
|
||||
db.push_batch(records.iter()).await.unwrap();
|
||||
|
||||
let mut record = db
|
||||
.head(tail.host, tail.tag.as_str())
|
||||
.await
|
||||
.expect("in memory sqlite should not fail")
|
||||
.expect("entry exists");
|
||||
|
||||
let mut count = 1;
|
||||
|
||||
while let Some(next) = db.next(&record).await.unwrap() {
|
||||
assert_eq!(record.id, next.clone().parent.unwrap());
|
||||
record = next;
|
||||
|
||||
count += 1;
|
||||
}
|
||||
|
||||
assert_eq!(count, 1000);
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,7 @@
|
||||
use async_trait::async_trait;
|
||||
use eyre::Result;
|
||||
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
|
||||
/// A record store stores records
|
||||
/// In more detail - we tend to need to process this into _another_ format to actually query it.
|
||||
/// As is, the record store is intended as the source of truth for arbitratry data, which could
|
||||
@ -23,19 +22,30 @@ pub trait Store {
|
||||
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
|
||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
|
||||
|
||||
/// Get the record that follows this record
|
||||
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
|
||||
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||
|
||||
/// Get the next `limit` records, after and including the given index
|
||||
async fn next(
|
||||
&self,
|
||||
host: HostId,
|
||||
tag: &str,
|
||||
idx: RecordIdx,
|
||||
limit: u64,
|
||||
) -> Result<Vec<Record<EncryptedData>>>;
|
||||
|
||||
/// Get the first record for a given host and tag
|
||||
async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||
async fn idx(
|
||||
&self,
|
||||
host: HostId,
|
||||
tag: &str,
|
||||
idx: RecordIdx,
|
||||
) -> Result<Option<Record<EncryptedData>>>;
|
||||
|
||||
/// Get the last record for a given host and tag
|
||||
async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||
async fn status(&self) -> Result<RecordStatus>;
|
||||
|
||||
// Get the last record for all hosts for a given tag, useful for the read path of apps.
|
||||
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
|
||||
|
||||
// Get the latest host/tag/record tuple for every set in the store. useful for building an
|
||||
// index
|
||||
async fn tail_records(&self) -> Result<RecordIndex>;
|
||||
/// Get every start record for a given tag, regardless of host.
|
||||
/// Useful when actually operating on synchronized data, and will often have conflict
|
||||
/// resolution applied.
|
||||
async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
|
||||
}
|
||||
|
@ -1,27 +1,51 @@
|
||||
// do a sync :O
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use eyre::Result;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::store::Store;
|
||||
use crate::{api_client::Client, settings::Settings};
|
||||
|
||||
use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
|
||||
use atuin_common::record::{Diff, HostId, RecordIdx, RecordStatus};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SyncError {
|
||||
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
|
||||
LocalAheadOtherHost,
|
||||
|
||||
#[error("an issue with the local database occured")]
|
||||
LocalStoreError,
|
||||
|
||||
#[error("something has gone wrong with the sync logic: {msg:?}")]
|
||||
SyncLogicError { msg: String },
|
||||
|
||||
#[error("a request to the sync server failed")]
|
||||
RemoteRequestError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub enum Operation {
|
||||
// Either upload or download until the tail matches the below
|
||||
// Either upload or download until the states matches the below
|
||||
Upload {
|
||||
tail: RecordId,
|
||||
local: RecordIdx,
|
||||
remote: Option<RecordIdx>,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
},
|
||||
Download {
|
||||
tail: RecordId,
|
||||
local: Option<RecordIdx>,
|
||||
remote: RecordIdx,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
},
|
||||
Noop {
|
||||
host: HostId,
|
||||
tag: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
|
||||
pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
|
||||
let client = Client::new(
|
||||
&settings.sync_address,
|
||||
&settings.session_token,
|
||||
@ -29,8 +53,8 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
|
||||
settings.network_timeout,
|
||||
)?;
|
||||
|
||||
let local_index = store.tail_records().await?;
|
||||
let remote_index = client.record_index().await?;
|
||||
let local_index = store.status().await?;
|
||||
let remote_index = client.record_status().await?;
|
||||
|
||||
let diff = local_index.diff(&remote_index);
|
||||
|
||||
@ -41,39 +65,57 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
|
||||
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
|
||||
// In theory this could be done as a part of the diffing stage, but it's easier to reason
|
||||
// about and test this way
|
||||
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
|
||||
pub async fn operations(
|
||||
diffs: Vec<Diff>,
|
||||
_store: &impl Store,
|
||||
) -> Result<Vec<Operation>, SyncError> {
|
||||
let mut operations = Vec::with_capacity(diffs.len());
|
||||
|
||||
for diff in diffs {
|
||||
// First, try to fetch the tail
|
||||
// If it exists locally, then that means we need to update the remote
|
||||
// host until it has the same tail. Ie, upload.
|
||||
// If it does not exist locally, that means remote is ahead of us.
|
||||
// Therefore, we need to download until our local tail matches
|
||||
let record = store.get(diff.tail).await;
|
||||
let op = match (diff.local, diff.remote) {
|
||||
// We both have it! Could be either. Compare.
|
||||
(Some(local), Some(remote)) => match local.cmp(&remote) {
|
||||
Ordering::Equal => Operation::Noop {
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
},
|
||||
Ordering::Greater => Operation::Upload {
|
||||
local,
|
||||
remote: Some(remote),
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
},
|
||||
Ordering::Less => Operation::Download {
|
||||
local: Some(local),
|
||||
remote,
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
},
|
||||
},
|
||||
|
||||
let op = if record.is_ok() {
|
||||
// if local has the ID, then we should find the actual tail of this
|
||||
// store, so we know what we need to update the remote to.
|
||||
let tail = store
|
||||
.tail(diff.host, diff.tag.as_str())
|
||||
.await?
|
||||
.expect("failed to fetch last record, expected tag/host to exist");
|
||||
|
||||
// TODO(ellie) update the diffing so that it stores the context of the current tail
|
||||
// that way, we can determine how much we need to upload.
|
||||
// For now just keep uploading until tails match
|
||||
|
||||
Operation::Upload {
|
||||
tail: tail.id,
|
||||
// Remote has it, we don't. Gotta be download
|
||||
(None, Some(remote)) => Operation::Download {
|
||||
local: None,
|
||||
remote,
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
}
|
||||
} else {
|
||||
Operation::Download {
|
||||
tail: diff.tail,
|
||||
},
|
||||
|
||||
// We have it, remote doesn't. Gotta be upload.
|
||||
(Some(local), None) => Operation::Upload {
|
||||
local,
|
||||
remote: None,
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
},
|
||||
|
||||
// something is pretty fucked.
|
||||
(None, None) => {
|
||||
return Err(SyncError::SyncLogicError {
|
||||
msg: String::from(
|
||||
"diff has nothing for local or remote - (host, tag) does not exist",
|
||||
),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
@ -86,149 +128,130 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
|
||||
// with the same properties
|
||||
|
||||
operations.sort_by_key(|op| match op {
|
||||
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
|
||||
Operation::Download { tail, host, .. } => ("download", *host, *tail),
|
||||
Operation::Noop { host, tag } => (0, *host, tag.clone()),
|
||||
|
||||
Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
|
||||
|
||||
Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
|
||||
});
|
||||
|
||||
Ok(operations)
|
||||
}
|
||||
|
||||
async fn sync_upload(
|
||||
store: &mut impl Store,
|
||||
remote_index: &RecordIndex,
|
||||
store: &impl Store,
|
||||
client: &Client<'_>,
|
||||
op: (HostId, String, RecordId),
|
||||
) -> Result<i64> {
|
||||
host: HostId,
|
||||
tag: String,
|
||||
local: RecordIdx,
|
||||
remote: Option<RecordIdx>,
|
||||
) -> Result<i64, SyncError> {
|
||||
let remote = remote.unwrap_or(0);
|
||||
let expected = local - remote;
|
||||
let upload_page_size = 100;
|
||||
let mut total = 0;
|
||||
|
||||
// so. we have an upload operation, with the tail representing the state
|
||||
// we want to get the remote to
|
||||
let current_tail = remote_index.get(op.0, op.1.clone());
|
||||
let mut progress = 0;
|
||||
|
||||
println!(
|
||||
"Syncing local {:?}/{}/{:?}, remote has {:?}",
|
||||
op.0, op.1, op.2, current_tail
|
||||
"Uploading {} records to {}/{}",
|
||||
expected,
|
||||
host.0.as_simple(),
|
||||
tag
|
||||
);
|
||||
|
||||
let start = if let Some(current_tail) = current_tail {
|
||||
current_tail
|
||||
} else {
|
||||
store
|
||||
.head(op.0, op.1.as_str())
|
||||
// preload with the first entry if remote does not know of this store
|
||||
loop {
|
||||
let page = store
|
||||
.next(host, tag.as_str(), remote + progress, upload_page_size)
|
||||
.await
|
||||
.expect("failed to fetch host/tag head")
|
||||
.expect("host/tag not in current index")
|
||||
.id
|
||||
};
|
||||
.map_err(|e| {
|
||||
error!("failed to read upload page: {e:?}");
|
||||
|
||||
debug!("starting push to remote from: {:?}", start);
|
||||
SyncError::LocalStoreError
|
||||
})?;
|
||||
|
||||
// we have the start point for sync. it is either the head of the store if
|
||||
// the remote has no data for it, or the tail that the remote has
|
||||
// we need to iterate from the remote tail, and keep going until
|
||||
// remote tail = current local tail
|
||||
client.post_records(&page).await.map_err(|e| {
|
||||
error!("failed to post records: {e:?}");
|
||||
|
||||
let mut record = if current_tail.is_some() {
|
||||
let r = store.get(start).await.unwrap();
|
||||
store.next(&r).await?
|
||||
} else {
|
||||
Some(store.get(start).await.unwrap())
|
||||
};
|
||||
SyncError::RemoteRequestError
|
||||
})?;
|
||||
|
||||
let mut buf = Vec::with_capacity(upload_page_size);
|
||||
println!(
|
||||
"uploaded {} to remote, progress {}/{}",
|
||||
page.len(),
|
||||
progress,
|
||||
expected
|
||||
);
|
||||
progress += page.len() as u64;
|
||||
|
||||
while let Some(r) = record {
|
||||
if buf.len() < upload_page_size {
|
||||
buf.push(r.clone());
|
||||
} else {
|
||||
client.post_records(&buf).await?;
|
||||
|
||||
// can we reset what we have? len = 0 but keep capacity
|
||||
buf = Vec::with_capacity(upload_page_size);
|
||||
if progress >= expected {
|
||||
break;
|
||||
}
|
||||
record = store.next(&r).await?;
|
||||
|
||||
total += 1;
|
||||
}
|
||||
|
||||
if !buf.is_empty() {
|
||||
client.post_records(&buf).await?;
|
||||
}
|
||||
|
||||
Ok(total)
|
||||
Ok(progress as i64)
|
||||
}
|
||||
|
||||
async fn sync_download(
|
||||
store: &mut impl Store,
|
||||
remote_index: &RecordIndex,
|
||||
store: &impl Store,
|
||||
client: &Client<'_>,
|
||||
op: (HostId, String, RecordId),
|
||||
) -> Result<i64> {
|
||||
// TODO(ellie): implement variable page sizing like on history sync
|
||||
let download_page_size = 1000;
|
||||
host: HostId,
|
||||
tag: String,
|
||||
local: Option<RecordIdx>,
|
||||
remote: RecordIdx,
|
||||
) -> Result<i64, SyncError> {
|
||||
let local = local.unwrap_or(0);
|
||||
let expected = remote - local;
|
||||
let download_page_size = 100;
|
||||
let mut progress = 0;
|
||||
|
||||
let mut total = 0;
|
||||
println!(
|
||||
"Downloading {} records from {}/{}",
|
||||
expected,
|
||||
host.0.as_simple(),
|
||||
tag
|
||||
);
|
||||
|
||||
// We know that the remote is ahead of us, so let's keep downloading until both
|
||||
// 1) The remote stops returning full pages
|
||||
// 2) The tail equals what we expect
|
||||
//
|
||||
// If (1) occurs without (2), then something is wrong with our index calculation
|
||||
// and we should bail.
|
||||
let remote_tail = remote_index
|
||||
.get(op.0, op.1.clone())
|
||||
.expect("remote index does not contain expected tail during download");
|
||||
let local_tail = store.tail(op.0, op.1.as_str()).await?;
|
||||
//
|
||||
// We expect that the operations diff will represent the desired state
|
||||
// In this case, that contains the remote tail.
|
||||
assert_eq!(remote_tail, op.2);
|
||||
// preload with the first entry if remote does not know of this store
|
||||
loop {
|
||||
let page = client
|
||||
.next_records(host, tag.clone(), local + progress, download_page_size)
|
||||
.await
|
||||
.map_err(|_| SyncError::RemoteRequestError)?;
|
||||
|
||||
println!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
|
||||
store
|
||||
.push_batch(page.iter())
|
||||
.await
|
||||
.map_err(|_| SyncError::LocalStoreError)?;
|
||||
|
||||
let mut records = client
|
||||
.next_records(
|
||||
op.0,
|
||||
op.1.clone(),
|
||||
local_tail.map(|r| r.id),
|
||||
download_page_size,
|
||||
)
|
||||
.await?;
|
||||
println!(
|
||||
"downloaded {} records from remote, progress {}/{}",
|
||||
page.len(),
|
||||
progress,
|
||||
expected
|
||||
);
|
||||
|
||||
while !records.is_empty() {
|
||||
total += std::cmp::min(download_page_size, records.len() as u64);
|
||||
store.push_batch(records.iter()).await?;
|
||||
progress += page.len() as u64;
|
||||
|
||||
if records.last().unwrap().id == remote_tail {
|
||||
if progress >= expected {
|
||||
break;
|
||||
}
|
||||
|
||||
records = client
|
||||
.next_records(
|
||||
op.0,
|
||||
op.1.clone(),
|
||||
records.last().map(|r| r.id),
|
||||
download_page_size,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(total as i64)
|
||||
Ok(progress as i64)
|
||||
}
|
||||
|
||||
pub async fn sync_remote(
|
||||
operations: Vec<Operation>,
|
||||
remote_index: &RecordIndex,
|
||||
local_store: &mut impl Store,
|
||||
local_store: &impl Store,
|
||||
settings: &Settings,
|
||||
) -> Result<(i64, i64)> {
|
||||
) -> Result<(i64, i64), SyncError> {
|
||||
let client = Client::new(
|
||||
&settings.sync_address,
|
||||
&settings.session_token,
|
||||
settings.network_connect_timeout,
|
||||
settings.network_timeout,
|
||||
)?;
|
||||
)
|
||||
.expect("failed to create client");
|
||||
|
||||
let mut uploaded = 0;
|
||||
let mut downloaded = 0;
|
||||
@ -236,14 +259,23 @@ pub async fn sync_remote(
|
||||
// this can totally run in parallel, but lets get it working first
|
||||
for i in operations {
|
||||
match i {
|
||||
Operation::Upload { tail, host, tag } => {
|
||||
uploaded +=
|
||||
sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
|
||||
}
|
||||
Operation::Download { tail, host, tag } => {
|
||||
downloaded +=
|
||||
sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
|
||||
Operation::Upload {
|
||||
host,
|
||||
tag,
|
||||
local,
|
||||
remote,
|
||||
} => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?,
|
||||
|
||||
Operation::Download {
|
||||
host,
|
||||
tag,
|
||||
local,
|
||||
remote,
|
||||
} => {
|
||||
downloaded += sync_download(local_store, &client, host, tag, local, remote).await?
|
||||
}
|
||||
|
||||
Operation::Noop { .. } => continue,
|
||||
}
|
||||
}
|
||||
|
||||
@ -264,13 +296,16 @@ mod tests {
|
||||
|
||||
fn test_record() -> Record<EncryptedData> {
|
||||
Record::builder()
|
||||
.host(HostId(atuin_common::utils::uuid_v7()))
|
||||
.host(atuin_common::record::Host::new(HostId(
|
||||
atuin_common::utils::uuid_v7(),
|
||||
)))
|
||||
.version("v1".into())
|
||||
.tag(atuin_common::utils::uuid_v7().simple().to_string())
|
||||
.data(EncryptedData {
|
||||
data: String::new(),
|
||||
content_encryption_key: String::new(),
|
||||
})
|
||||
.idx(0)
|
||||
.build()
|
||||
}
|
||||
|
||||
@ -296,8 +331,8 @@ mod tests {
|
||||
remote_store.push(&i).await.unwrap();
|
||||
}
|
||||
|
||||
let local_index = local_store.tail_records().await.unwrap();
|
||||
let remote_index = remote_store.tail_records().await.unwrap();
|
||||
let local_index = local_store.status().await.unwrap();
|
||||
let remote_index = remote_store.status().await.unwrap();
|
||||
|
||||
let diff = local_index.diff(&remote_index);
|
||||
|
||||
@ -320,9 +355,10 @@ mod tests {
|
||||
assert_eq!(
|
||||
operations[0],
|
||||
Operation::Upload {
|
||||
host: record.host,
|
||||
host: record.host.id,
|
||||
tag: record.tag,
|
||||
tail: record.id
|
||||
local: record.idx,
|
||||
remote: None,
|
||||
}
|
||||
);
|
||||
}
|
||||
@ -333,12 +369,14 @@ mod tests {
|
||||
// another. One upload, one download
|
||||
|
||||
let shared_record = test_record();
|
||||
|
||||
let remote_ahead = test_record();
|
||||
|
||||
let local_ahead = shared_record
|
||||
.new_child(vec![1, 2, 3])
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
assert_eq!(local_ahead.idx, 1);
|
||||
|
||||
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
|
||||
let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
|
||||
|
||||
@ -350,15 +388,19 @@ mod tests {
|
||||
assert_eq!(
|
||||
operations,
|
||||
vec![
|
||||
Operation::Download {
|
||||
tail: remote_ahead.id,
|
||||
host: remote_ahead.host,
|
||||
tag: remote_ahead.tag,
|
||||
},
|
||||
// Or in otherwords, local is ahead by one
|
||||
Operation::Upload {
|
||||
tail: local_ahead.id,
|
||||
host: local_ahead.host,
|
||||
host: local_ahead.host.id,
|
||||
tag: local_ahead.tag,
|
||||
local: 1,
|
||||
remote: Some(0),
|
||||
},
|
||||
// Or in other words, remote knows of a record in an entirely new store (tag)
|
||||
Operation::Download {
|
||||
host: remote_ahead.host.id,
|
||||
tag: remote_ahead.tag,
|
||||
local: None,
|
||||
remote: 0,
|
||||
},
|
||||
]
|
||||
);
|
||||
@ -371,66 +413,160 @@ mod tests {
|
||||
// One known only by remote
|
||||
|
||||
let shared_record = test_record();
|
||||
let local_only = test_record();
|
||||
|
||||
let remote_known = test_record();
|
||||
let local_known = test_record();
|
||||
let local_only_20 = test_record();
|
||||
let local_only_21 = local_only_20
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let local_only_22 = local_only_21
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let local_only_23 = local_only_22
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let remote_only = test_record();
|
||||
|
||||
let remote_only_20 = test_record();
|
||||
let remote_only_21 = remote_only_20
|
||||
.append(vec![2, 3, 2])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let remote_only_22 = remote_only_21
|
||||
.append(vec![2, 3, 2])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let remote_only_23 = remote_only_22
|
||||
.append(vec![2, 3, 2])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let remote_only_24 = remote_only_23
|
||||
.append(vec![2, 3, 2])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let second_shared = test_record();
|
||||
let second_shared_remote_ahead = second_shared
|
||||
.new_child(vec![1, 2, 3])
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let second_shared_remote_ahead2 = second_shared_remote_ahead
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let local_ahead = shared_record
|
||||
.new_child(vec![1, 2, 3])
|
||||
let third_shared = test_record();
|
||||
let third_shared_local_ahead = third_shared
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let third_shared_local_ahead2 = third_shared_local_ahead
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let fourth_shared = test_record();
|
||||
let fourth_shared_remote_ahead = fourth_shared
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let local = vec![
|
||||
shared_record.clone(),
|
||||
second_shared.clone(),
|
||||
local_known.clone(),
|
||||
local_ahead.clone(),
|
||||
third_shared.clone(),
|
||||
fourth_shared.clone(),
|
||||
fourth_shared_remote_ahead.clone(),
|
||||
// single store, only local has it
|
||||
local_only.clone(),
|
||||
// bigger store, also only known by local
|
||||
local_only_20.clone(),
|
||||
local_only_21.clone(),
|
||||
local_only_22.clone(),
|
||||
local_only_23.clone(),
|
||||
// another shared store, but local is ahead on this one
|
||||
third_shared_local_ahead.clone(),
|
||||
third_shared_local_ahead2.clone(),
|
||||
];
|
||||
|
||||
let remote = vec![
|
||||
remote_only.clone(),
|
||||
remote_only_20.clone(),
|
||||
remote_only_21.clone(),
|
||||
remote_only_22.clone(),
|
||||
remote_only_23.clone(),
|
||||
remote_only_24.clone(),
|
||||
shared_record.clone(),
|
||||
second_shared.clone(),
|
||||
third_shared.clone(),
|
||||
second_shared_remote_ahead.clone(),
|
||||
remote_known.clone(),
|
||||
second_shared_remote_ahead2.clone(),
|
||||
fourth_shared.clone(),
|
||||
fourth_shared_remote_ahead.clone(),
|
||||
fourth_shared_remote_ahead2.clone(),
|
||||
]; // remote knows about the already-synced, and one new record in a new store
|
||||
|
||||
let (store, diff) = build_test_diff(local, remote).await;
|
||||
let operations = sync::operations(diff, &store).await.unwrap();
|
||||
|
||||
assert_eq!(operations.len(), 4);
|
||||
assert_eq!(operations.len(), 7);
|
||||
|
||||
let mut result_ops = vec![
|
||||
// We started with a shared record, but the remote knows of two newer records in the
|
||||
// same store
|
||||
Operation::Download {
|
||||
tail: remote_known.id,
|
||||
host: remote_known.host,
|
||||
tag: remote_known.tag,
|
||||
local: Some(0),
|
||||
remote: 2,
|
||||
host: second_shared_remote_ahead.host.id,
|
||||
tag: second_shared_remote_ahead.tag,
|
||||
},
|
||||
// We have a shared record, local knows of the first two but not the last
|
||||
Operation::Download {
|
||||
tail: second_shared_remote_ahead.id,
|
||||
host: second_shared.host,
|
||||
tag: second_shared.tag,
|
||||
local: Some(1),
|
||||
remote: 2,
|
||||
host: fourth_shared_remote_ahead2.host.id,
|
||||
tag: fourth_shared_remote_ahead2.tag,
|
||||
},
|
||||
Operation::Upload {
|
||||
tail: local_ahead.id,
|
||||
host: local_ahead.host,
|
||||
tag: local_ahead.tag,
|
||||
// Remote knows of a store with a single record that local does not have
|
||||
Operation::Download {
|
||||
local: None,
|
||||
remote: 0,
|
||||
host: remote_only.host.id,
|
||||
tag: remote_only.tag,
|
||||
},
|
||||
// Remote knows of a store with a bunch of records that local does not have
|
||||
Operation::Download {
|
||||
local: None,
|
||||
remote: 4,
|
||||
host: remote_only_20.host.id,
|
||||
tag: remote_only_20.tag,
|
||||
},
|
||||
// Local knows of a record in a store that remote does not have
|
||||
Operation::Upload {
|
||||
tail: local_known.id,
|
||||
host: local_known.host,
|
||||
tag: local_known.tag,
|
||||
local: 0,
|
||||
remote: None,
|
||||
host: local_only.host.id,
|
||||
tag: local_only.tag,
|
||||
},
|
||||
// Local knows of 4 records in a store that remote does not have
|
||||
Operation::Upload {
|
||||
local: 3,
|
||||
remote: None,
|
||||
host: local_only_20.host.id,
|
||||
tag: local_only_20.tag,
|
||||
},
|
||||
// Local knows of 2 more records in a shared store that remote only has one of
|
||||
Operation::Upload {
|
||||
local: 2,
|
||||
remote: Some(0),
|
||||
host: third_shared.host.id,
|
||||
tag: third_shared.tag,
|
||||
},
|
||||
];
|
||||
|
||||
result_ops.sort_by_key(|op| match op {
|
||||
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
|
||||
Operation::Download { tail, host, .. } => ("download", *host, *tail),
|
||||
Operation::Noop { host, tag } => (0, *host, tag.clone()),
|
||||
|
||||
Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
|
||||
|
||||
Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
|
||||
});
|
||||
|
||||
assert_eq!(operations, result_ops);
|
||||
assert_eq!(result_ops, operations);
|
||||
}
|
||||
}
|
||||
|
@ -173,6 +173,11 @@ impl Default for Stats {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Default)]
|
||||
pub struct Sync {
|
||||
pub records: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Settings {
|
||||
pub dialect: Dialect,
|
||||
@ -217,6 +222,9 @@ pub struct Settings {
|
||||
#[serde(default)]
|
||||
pub stats: Stats,
|
||||
|
||||
#[serde(default)]
|
||||
pub sync: Sync,
|
||||
|
||||
// This is automatically loaded when settings is created. Do not set in
|
||||
// config! Keep secrets and settings apart.
|
||||
#[serde(skip)]
|
||||
@ -427,6 +435,7 @@ impl Settings {
|
||||
// muscle memory.
|
||||
// New users will get the new default, that is more similar to what they are used to.
|
||||
.set_default("enter_accept", false)?
|
||||
.set_default("sync.records", false)?
|
||||
.add_source(
|
||||
Environment::with_prefix("atuin")
|
||||
.prefix_separator("_")
|
||||
|
@ -14,13 +14,34 @@ pub struct EncryptedData {
|
||||
pub content_encryption_key: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, PartialEq, PartialOrd, Ord, Eq)]
|
||||
pub struct Diff {
|
||||
pub host: HostId,
|
||||
pub tag: String,
|
||||
pub tail: RecordId,
|
||||
pub local: Option<RecordIdx>,
|
||||
pub remote: Option<RecordIdx>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
|
||||
pub struct Host {
|
||||
pub id: HostId,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl Host {
|
||||
pub fn new(id: HostId) -> Self {
|
||||
Host {
|
||||
id,
|
||||
name: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_uuid!(RecordId);
|
||||
new_uuid!(HostId);
|
||||
|
||||
pub type RecordIdx = u64;
|
||||
|
||||
/// A single record stored inside of our local database
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
|
||||
pub struct Record<Data> {
|
||||
@ -28,18 +49,14 @@ pub struct Record<Data> {
|
||||
#[builder(default = RecordId(crate::utils::uuid_v7()))]
|
||||
pub id: RecordId,
|
||||
|
||||
/// The integer record ID. This is only unique per (host, tag).
|
||||
pub idx: RecordIdx,
|
||||
|
||||
/// The unique ID of the host.
|
||||
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
|
||||
// as strings. I would rather avoid normalization, so store as UUID binary instead of
|
||||
// encoding to a string and wasting much more storage.
|
||||
pub host: HostId,
|
||||
|
||||
/// The ID of the parent entry
|
||||
// A store is technically just a double linked list
|
||||
// We can do some cheating with the timestamps, but should not rely upon them.
|
||||
// Clocks are tricksy.
|
||||
#[builder(default)]
|
||||
pub parent: Option<RecordId>,
|
||||
pub host: Host,
|
||||
|
||||
/// The creation time in nanoseconds since unix epoch
|
||||
#[builder(default = time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)]
|
||||
@ -56,25 +73,22 @@ pub struct Record<Data> {
|
||||
pub data: Data,
|
||||
}
|
||||
|
||||
new_uuid!(RecordId);
|
||||
new_uuid!(HostId);
|
||||
|
||||
/// Extra data from the record that should be encoded in the data
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct AdditionalData<'a> {
|
||||
pub id: &'a RecordId,
|
||||
pub idx: &'a u64,
|
||||
pub version: &'a str,
|
||||
pub tag: &'a str,
|
||||
pub host: &'a HostId,
|
||||
pub parent: Option<&'a RecordId>,
|
||||
}
|
||||
|
||||
impl<Data> Record<Data> {
|
||||
pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
|
||||
pub fn append(&self, data: Vec<u8>) -> Record<DecryptedData> {
|
||||
Record::builder()
|
||||
.host(self.host)
|
||||
.host(self.host.clone())
|
||||
.version(self.version.clone())
|
||||
.parent(Some(self.id))
|
||||
.idx(self.idx + 1)
|
||||
.tag(self.tag.clone())
|
||||
.data(DecryptedData(data))
|
||||
.build()
|
||||
@ -84,74 +98,76 @@ impl<Data> Record<Data> {
|
||||
/// An index representing the current state of the record stores
|
||||
/// This can be both remote, or local, and compared in either direction
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RecordIndex {
|
||||
// A map of host -> tag -> tail
|
||||
pub hosts: HashMap<HostId, HashMap<String, RecordId>>,
|
||||
pub struct RecordStatus {
|
||||
// A map of host -> tag -> max(idx)
|
||||
pub hosts: HashMap<HostId, HashMap<String, RecordIdx>>,
|
||||
}
|
||||
|
||||
impl Default for RecordIndex {
|
||||
impl Default for RecordStatus {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Extend<(HostId, String, RecordId)> for RecordIndex {
|
||||
fn extend<T: IntoIterator<Item = (HostId, String, RecordId)>>(&mut self, iter: T) {
|
||||
for (host, tag, tail_id) in iter {
|
||||
self.set_raw(host, tag, tail_id);
|
||||
impl Extend<(HostId, String, RecordIdx)> for RecordStatus {
|
||||
fn extend<T: IntoIterator<Item = (HostId, String, RecordIdx)>>(&mut self, iter: T) {
|
||||
for (host, tag, tail_idx) in iter {
|
||||
self.set_raw(host, tag, tail_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordIndex {
|
||||
pub fn new() -> RecordIndex {
|
||||
RecordIndex {
|
||||
impl RecordStatus {
|
||||
pub fn new() -> RecordStatus {
|
||||
RecordStatus {
|
||||
hosts: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a new tail record into the store
|
||||
pub fn set(&mut self, tail: Record<DecryptedData>) {
|
||||
self.set_raw(tail.host, tail.tag, tail.id)
|
||||
self.set_raw(tail.host.id, tail.tag, tail.idx)
|
||||
}
|
||||
|
||||
pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordId) {
|
||||
pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) {
|
||||
self.hosts.entry(host).or_default().insert(tag, tail_id);
|
||||
}
|
||||
|
||||
pub fn get(&self, host: HostId, tag: String) -> Option<RecordId> {
|
||||
pub fn get(&self, host: HostId, tag: String) -> Option<RecordIdx> {
|
||||
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
|
||||
}
|
||||
|
||||
/// Diff this index with another, likely remote index.
|
||||
/// The two diffs can then be reconciled, and the optimal change set calculated
|
||||
/// Returns a tuple, with (host, tag, Option(OTHER))
|
||||
/// OTHER is set to the value of the tail on the other machine. For example, if the
|
||||
/// other machine has a different tail, it will be the differing tail. This is useful to
|
||||
/// check if the other index is ahead of us, or behind.
|
||||
/// If the other index does not have the (host, tag) pair, then the other value will be None.
|
||||
/// OTHER is set to the value of the idx on the other machine. If it is greater than our index,
|
||||
/// then we need to do some downloading. If it is smaller, then we need to do some uploading
|
||||
/// Note that we cannot upload if we are not the owner of the record store - hosts can only
|
||||
/// write to their own store.
|
||||
pub fn diff(&self, other: &Self) -> Vec<Diff> {
|
||||
let mut ret = Vec::new();
|
||||
|
||||
// First, we check if other has everything that self has
|
||||
for (host, tag_map) in self.hosts.iter() {
|
||||
for (tag, tail) in tag_map.iter() {
|
||||
for (tag, idx) in tag_map.iter() {
|
||||
match other.get(*host, tag.clone()) {
|
||||
// The other store is all up to date! No diff.
|
||||
Some(t) if t.eq(tail) => continue,
|
||||
Some(t) if t.eq(idx) => continue,
|
||||
|
||||
// The other store does exist, but it is either ahead or behind us. A diff regardless
|
||||
// The other store does exist, and it is either ahead or behind us. A diff regardless
|
||||
Some(t) => ret.push(Diff {
|
||||
host: *host,
|
||||
tag: tag.clone(),
|
||||
tail: t,
|
||||
local: Some(*idx),
|
||||
remote: Some(t),
|
||||
}),
|
||||
|
||||
// The other store does not exist :O
|
||||
None => ret.push(Diff {
|
||||
host: *host,
|
||||
tag: tag.clone(),
|
||||
tail: *tail,
|
||||
local: Some(*idx),
|
||||
remote: None,
|
||||
}),
|
||||
};
|
||||
}
|
||||
@ -162,7 +178,7 @@ impl RecordIndex {
|
||||
|
||||
// account for that!
|
||||
for (host, tag_map) in other.hosts.iter() {
|
||||
for (tag, tail) in tag_map.iter() {
|
||||
for (tag, idx) in tag_map.iter() {
|
||||
match self.get(*host, tag.clone()) {
|
||||
// If we have this host/tag combo, the comparison and diff will have already happened above
|
||||
Some(_) => continue,
|
||||
@ -170,13 +186,15 @@ impl RecordIndex {
|
||||
None => ret.push(Diff {
|
||||
host: *host,
|
||||
tag: tag.clone(),
|
||||
tail: *tail,
|
||||
remote: Some(*idx),
|
||||
local: None,
|
||||
}),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
ret.sort_by(|a, b| (a.host, a.tag.clone(), a.tail).cmp(&(b.host, b.tag.clone(), b.tail)));
|
||||
// Stability is a nice property to have
|
||||
ret.sort();
|
||||
ret
|
||||
}
|
||||
}
|
||||
@ -201,14 +219,14 @@ impl Record<DecryptedData> {
|
||||
id: &self.id,
|
||||
version: &self.version,
|
||||
tag: &self.tag,
|
||||
host: &self.host,
|
||||
parent: self.parent.as_ref(),
|
||||
host: &self.host.id,
|
||||
idx: &self.idx,
|
||||
};
|
||||
Record {
|
||||
data: E::encrypt(self.data, ad, key),
|
||||
id: self.id,
|
||||
host: self.host,
|
||||
parent: self.parent,
|
||||
idx: self.idx,
|
||||
timestamp: self.timestamp,
|
||||
version: self.version,
|
||||
tag: self.tag,
|
||||
@ -222,14 +240,14 @@ impl Record<EncryptedData> {
|
||||
id: &self.id,
|
||||
version: &self.version,
|
||||
tag: &self.tag,
|
||||
host: &self.host,
|
||||
parent: self.parent.as_ref(),
|
||||
host: &self.host.id,
|
||||
idx: &self.idx,
|
||||
};
|
||||
Ok(Record {
|
||||
data: E::decrypt(self.data, ad, key)?,
|
||||
id: self.id,
|
||||
host: self.host,
|
||||
parent: self.parent,
|
||||
idx: self.idx,
|
||||
timestamp: self.timestamp,
|
||||
version: self.version,
|
||||
tag: self.tag,
|
||||
@ -245,14 +263,14 @@ impl Record<EncryptedData> {
|
||||
id: &self.id,
|
||||
version: &self.version,
|
||||
tag: &self.tag,
|
||||
host: &self.host,
|
||||
parent: self.parent.as_ref(),
|
||||
host: &self.host.id,
|
||||
idx: &self.idx,
|
||||
};
|
||||
Ok(Record {
|
||||
data: E::re_encrypt(self.data, ad, old_key, new_key)?,
|
||||
id: self.id,
|
||||
host: self.host,
|
||||
parent: self.parent,
|
||||
idx: self.idx,
|
||||
timestamp: self.timestamp,
|
||||
version: self.version,
|
||||
tag: self.tag,
|
||||
@ -262,31 +280,32 @@ impl Record<EncryptedData> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::record::HostId;
|
||||
use crate::record::{Host, HostId};
|
||||
|
||||
use super::{DecryptedData, Diff, Record, RecordIndex};
|
||||
use super::{DecryptedData, Diff, Record, RecordStatus};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn test_record() -> Record<DecryptedData> {
|
||||
Record::builder()
|
||||
.host(HostId(crate::utils::uuid_v7()))
|
||||
.host(Host::new(HostId(crate::utils::uuid_v7())))
|
||||
.version("v1".into())
|
||||
.tag(crate::utils::uuid_v7().simple().to_string())
|
||||
.data(DecryptedData(vec![0, 1, 2, 3]))
|
||||
.idx(0)
|
||||
.build()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_index() {
|
||||
let mut index = RecordIndex::new();
|
||||
let mut index = RecordStatus::new();
|
||||
let record = test_record();
|
||||
|
||||
index.set(record.clone());
|
||||
|
||||
let tail = index.get(record.host, record.tag);
|
||||
let tail = index.get(record.host.id, record.tag);
|
||||
|
||||
assert_eq!(
|
||||
record.id,
|
||||
record.idx,
|
||||
tail.expect("tail not in store"),
|
||||
"tail in store did not match"
|
||||
);
|
||||
@ -294,17 +313,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn record_index_overwrite() {
|
||||
let mut index = RecordIndex::new();
|
||||
let mut index = RecordStatus::new();
|
||||
let record = test_record();
|
||||
let child = record.new_child(vec![1, 2, 3]);
|
||||
let child = record.append(vec![1, 2, 3]);
|
||||
|
||||
index.set(record.clone());
|
||||
index.set(child.clone());
|
||||
|
||||
let tail = index.get(record.host, record.tag);
|
||||
let tail = index.get(record.host.id, record.tag);
|
||||
|
||||
assert_eq!(
|
||||
child.id,
|
||||
child.idx,
|
||||
tail.expect("tail not in store"),
|
||||
"tail in store did not match"
|
||||
);
|
||||
@ -314,8 +333,8 @@ mod tests {
|
||||
fn record_index_no_diff() {
|
||||
// Here, they both have the same version and should have no diff
|
||||
|
||||
let mut index1 = RecordIndex::new();
|
||||
let mut index2 = RecordIndex::new();
|
||||
let mut index1 = RecordStatus::new();
|
||||
let mut index2 = RecordStatus::new();
|
||||
|
||||
let record1 = test_record();
|
||||
|
||||
@ -331,11 +350,11 @@ mod tests {
|
||||
fn record_index_single_diff() {
|
||||
// Here, they both have the same stores, but one is ahead by a single record
|
||||
|
||||
let mut index1 = RecordIndex::new();
|
||||
let mut index2 = RecordIndex::new();
|
||||
let mut index1 = RecordStatus::new();
|
||||
let mut index2 = RecordStatus::new();
|
||||
|
||||
let record1 = test_record();
|
||||
let record2 = record1.new_child(vec![1, 2, 3]);
|
||||
let record2 = record1.append(vec![1, 2, 3]);
|
||||
|
||||
index1.set(record1);
|
||||
index2.set(record2.clone());
|
||||
@ -346,9 +365,10 @@ mod tests {
|
||||
assert_eq!(
|
||||
diff[0],
|
||||
Diff {
|
||||
host: record2.host,
|
||||
host: record2.host.id,
|
||||
tag: record2.tag,
|
||||
tail: record2.id
|
||||
remote: Some(1),
|
||||
local: Some(0)
|
||||
}
|
||||
);
|
||||
}
|
||||
@ -356,14 +376,14 @@ mod tests {
|
||||
#[test]
|
||||
fn record_index_multi_diff() {
|
||||
// A much more complex case, with a bunch more checks
|
||||
let mut index1 = RecordIndex::new();
|
||||
let mut index2 = RecordIndex::new();
|
||||
let mut index1 = RecordStatus::new();
|
||||
let mut index2 = RecordStatus::new();
|
||||
|
||||
let store1record1 = test_record();
|
||||
let store1record2 = store1record1.new_child(vec![1, 2, 3]);
|
||||
let store1record2 = store1record1.append(vec![1, 2, 3]);
|
||||
|
||||
let store2record1 = test_record();
|
||||
let store2record2 = store2record1.new_child(vec![1, 2, 3]);
|
||||
let store2record2 = store2record1.append(vec![1, 2, 3]);
|
||||
|
||||
let store3record1 = test_record();
|
||||
|
||||
|
@ -14,7 +14,7 @@ use self::{
|
||||
models::{History, NewHistory, NewSession, NewUser, Session, User},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
|
||||
use tracing::instrument;
|
||||
@ -68,12 +68,12 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
|
||||
user: &User,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
start: Option<RecordId>,
|
||||
start: Option<RecordIdx>,
|
||||
count: u64,
|
||||
) -> DbResult<Vec<Record<EncryptedData>>>;
|
||||
|
||||
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
|
||||
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
|
||||
async fn status(&self, user: &User) -> DbResult<RecordStatus>;
|
||||
|
||||
async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
|
||||
-> DbResult<i64>;
|
||||
|
@ -0,0 +1,15 @@
|
||||
-- Add migration script here
|
||||
create table store (
|
||||
id uuid primary key, -- remember to use uuidv7 for happy indices <3
|
||||
client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically
|
||||
host uuid not null, -- a unique identifier for the host
|
||||
idx bigint not null, -- the index of the record in this store, identified by (host, tag)
|
||||
timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
|
||||
version text not null,
|
||||
tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
|
||||
data text not null, -- store the actual history data, encrypted. I don't wanna know!
|
||||
cek text not null,
|
||||
|
||||
user_id bigint not null, -- allow multiple users
|
||||
created_at timestamp not null default current_timestamp
|
||||
);
|
@ -0,0 +1,2 @@
|
||||
-- Add migration script here
|
||||
create unique index record_uniq ON store(user_id, host, tag, idx);
|
@ -1,7 +1,7 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
|
||||
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
|
||||
use atuin_server_database::{Database, DbError, DbResult};
|
||||
use futures_util::TryStreamExt;
|
||||
@ -11,6 +11,7 @@ use sqlx::Row;
|
||||
|
||||
use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
|
||||
use tracing::instrument;
|
||||
use uuid::Uuid;
|
||||
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
|
||||
|
||||
mod wrappers;
|
||||
@ -361,16 +362,16 @@ impl Database for Postgres {
|
||||
let id = atuin_common::utils::uuid_v7();
|
||||
|
||||
sqlx::query(
|
||||
"insert into records
|
||||
(id, client_id, host, parent, timestamp, version, tag, data, cek, user_id)
|
||||
"insert into store
|
||||
(id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
on conflict do nothing
|
||||
",
|
||||
)
|
||||
.bind(id)
|
||||
.bind(i.id)
|
||||
.bind(i.host)
|
||||
.bind(i.parent)
|
||||
.bind(i.host.id)
|
||||
.bind(i.idx as i64)
|
||||
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
|
||||
.bind(&i.version)
|
||||
.bind(&i.tag)
|
||||
@ -393,62 +394,69 @@ impl Database for Postgres {
|
||||
user: &User,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
start: Option<RecordId>,
|
||||
start: Option<RecordIdx>,
|
||||
count: u64,
|
||||
) -> DbResult<Vec<Record<EncryptedData>>> {
|
||||
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
|
||||
let mut ret = Vec::with_capacity(count as usize);
|
||||
let mut parent = start;
|
||||
let start = start.unwrap_or(0);
|
||||
|
||||
// yeah let's do something better
|
||||
for _ in 0..count {
|
||||
// a very much not ideal query. but it's simple at least?
|
||||
// we are basically using postgres as a kv store here, so... maybe consider using an actual
|
||||
// kv store?
|
||||
let record: Result<DbRecord, DbError> = sqlx::query_as(
|
||||
"select client_id, host, parent, timestamp, version, tag, data, cek from records
|
||||
let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
|
||||
"select client_id, host, idx, timestamp, version, tag, data, cek from store
|
||||
where user_id = $1
|
||||
and tag = $2
|
||||
and host = $3
|
||||
and parent is not distinct from $4",
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(tag.clone())
|
||||
.bind(host)
|
||||
.bind(parent)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error);
|
||||
and idx >= $4
|
||||
order by idx asc
|
||||
limit $5",
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(tag.clone())
|
||||
.bind(host)
|
||||
.bind(start as i64)
|
||||
.bind(count as i64)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error);
|
||||
|
||||
match record {
|
||||
Ok(record) => {
|
||||
let record: Record<EncryptedData> = record.into();
|
||||
ret.push(record.clone());
|
||||
let ret = match records {
|
||||
Ok(records) => {
|
||||
let records: Vec<Record<EncryptedData>> = records
|
||||
.into_iter()
|
||||
.map(|f| {
|
||||
let record: Record<EncryptedData> = f.into();
|
||||
record
|
||||
})
|
||||
.collect();
|
||||
|
||||
parent = Some(record.id);
|
||||
}
|
||||
Err(DbError::NotFound) => {
|
||||
tracing::debug!("hit tail of store: {:?}/{}", host, tag);
|
||||
return Ok(ret);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
records
|
||||
}
|
||||
}
|
||||
Err(DbError::NotFound) => {
|
||||
tracing::debug!("no records found in store: {:?}/{}", host, tag);
|
||||
return Ok(vec![]);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> {
|
||||
const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0 and user_id = $1;";
|
||||
async fn status(&self, user: &User) -> DbResult<RecordStatus> {
|
||||
const STATUS_SQL: &str =
|
||||
"select host, tag, max(idx) from store where user_id = $1 group by host, tag";
|
||||
|
||||
let res = sqlx::query_as(TAIL_RECORDS_SQL)
|
||||
let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
|
||||
.bind(user.id)
|
||||
.fetch(&self.pool)
|
||||
.try_collect()
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(res)
|
||||
let mut status = RecordStatus::new();
|
||||
|
||||
for i in res {
|
||||
status.set_raw(HostId(i.0), i.1, i.2 as u64);
|
||||
}
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
use ::sqlx::{FromRow, Result};
|
||||
use atuin_common::record::{EncryptedData, Record};
|
||||
use atuin_common::record::{EncryptedData, Host, Record};
|
||||
use atuin_server_database::models::{History, Session, User};
|
||||
use sqlx::{postgres::PgRow, Row};
|
||||
use time::PrimitiveDateTime;
|
||||
@ -51,6 +51,7 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
|
||||
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
|
||||
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
|
||||
let timestamp: i64 = row.try_get("timestamp")?;
|
||||
let idx: i64 = row.try_get("idx")?;
|
||||
|
||||
let data = EncryptedData {
|
||||
data: row.try_get("data")?,
|
||||
@ -59,8 +60,8 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
|
||||
|
||||
Ok(Self(Record {
|
||||
id: row.try_get("client_id")?,
|
||||
host: row.try_get("host")?,
|
||||
parent: row.try_get("parent")?,
|
||||
host: Host::new(row.try_get("host")?),
|
||||
idx: idx as u64,
|
||||
timestamp: timestamp as u64,
|
||||
version: row.try_get("version")?,
|
||||
tag: row.try_get("tag")?,
|
||||
|
@ -8,6 +8,7 @@ pub mod history;
|
||||
pub mod record;
|
||||
pub mod status;
|
||||
pub mod user;
|
||||
pub mod v0;
|
||||
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
|
@ -1,109 +1,46 @@
|
||||
use axum::{extract::Query, extract::State, Json};
|
||||
use axum::{response::IntoResponse, Json};
|
||||
use http::StatusCode;
|
||||
use metrics::counter;
|
||||
use serde::Deserialize;
|
||||
use tracing::{error, instrument};
|
||||
use serde_json::json;
|
||||
use tracing::instrument;
|
||||
|
||||
use super::{ErrorResponse, ErrorResponseStatus, RespExt};
|
||||
use crate::router::{AppState, UserAuth};
|
||||
use crate::router::UserAuth;
|
||||
use atuin_server_database::Database;
|
||||
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||
use atuin_common::record::{EncryptedData, Record};
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn post<DB: Database>(
|
||||
UserAuth(user): UserAuth,
|
||||
state: State<AppState<DB>>,
|
||||
Json(records): Json<Vec<Record<EncryptedData>>>,
|
||||
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||
let State(AppState { database, settings }) = state;
|
||||
// anyone who has actually used the old record store (a very small number) will see this error
|
||||
// upon trying to sync.
|
||||
// 1. The status endpoint will say that the server has nothing
|
||||
// 2. The client will try to upload local records
|
||||
// 3. Sync will fail with this error
|
||||
|
||||
tracing::debug!(
|
||||
count = records.len(),
|
||||
user = user.username,
|
||||
"request to add records"
|
||||
// If the client has no local records, they will see the empty index and do nothing. For the
|
||||
// vast majority of users, this is the case.
|
||||
return Err(
|
||||
ErrorResponse::reply("record store deprecated; please upgrade")
|
||||
.with_status(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
|
||||
counter!("atuin_record_uploaded", records.len() as u64);
|
||||
|
||||
let too_big = records
|
||||
.iter()
|
||||
.any(|r| r.data.data.len() >= settings.max_record_size || settings.max_record_size == 0);
|
||||
|
||||
if too_big {
|
||||
counter!("atuin_record_too_large", 1);
|
||||
|
||||
return Err(
|
||||
ErrorResponse::reply("could not add records; record too large")
|
||||
.with_status(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(e) = database.add_records(&user, &records).await {
|
||||
error!("failed to add record: {}", e);
|
||||
|
||||
return Err(ErrorResponse::reply("failed to add record")
|
||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn index<DB: Database>(
|
||||
UserAuth(user): UserAuth,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> {
|
||||
let State(AppState {
|
||||
database,
|
||||
settings: _,
|
||||
}) = state;
|
||||
pub async fn index<DB: Database>(UserAuth(user): UserAuth) -> axum::response::Response {
|
||||
let ret = json!({
|
||||
"hosts": {}
|
||||
});
|
||||
|
||||
let record_index = match database.tail_records(&user).await {
|
||||
Ok(index) => index,
|
||||
Err(e) => {
|
||||
error!("failed to get record index: {}", e);
|
||||
|
||||
return Err(ErrorResponse::reply("failed to calculate record index")
|
||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(record_index))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct NextParams {
|
||||
host: HostId,
|
||||
tag: String,
|
||||
start: Option<RecordId>,
|
||||
count: u64,
|
||||
ret.to_string().into_response()
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn next<DB: Database>(
|
||||
params: Query<NextParams>,
|
||||
pub async fn next(
|
||||
UserAuth(user): UserAuth,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> {
|
||||
let State(AppState {
|
||||
database,
|
||||
settings: _,
|
||||
}) = state;
|
||||
let params = params.0;
|
||||
|
||||
let records = match database
|
||||
.next_records(&user, params.host, params.tag, params.start, params.count)
|
||||
.await
|
||||
{
|
||||
Ok(records) => records,
|
||||
Err(e) => {
|
||||
error!("failed to get record index: {}", e);
|
||||
|
||||
return Err(ErrorResponse::reply("failed to calculate record index")
|
||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
}
|
||||
};
|
||||
let records = Vec::new();
|
||||
|
||||
Ok(Json(records))
|
||||
}
|
||||
|
1
atuin-server/src/handlers/v0/mod.rs
Normal file
1
atuin-server/src/handlers/v0/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub(crate) mod record;
|
111
atuin-server/src/handlers/v0/record.rs
Normal file
111
atuin-server/src/handlers/v0/record.rs
Normal file
@ -0,0 +1,111 @@
|
||||
use axum::{extract::Query, extract::State, Json};
|
||||
use http::StatusCode;
|
||||
use metrics::counter;
|
||||
use serde::Deserialize;
|
||||
use tracing::{error, instrument};
|
||||
|
||||
use crate::{
|
||||
handlers::{ErrorResponse, ErrorResponseStatus, RespExt},
|
||||
router::{AppState, UserAuth},
|
||||
};
|
||||
use atuin_server_database::Database;
|
||||
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn post<DB: Database>(
|
||||
UserAuth(user): UserAuth,
|
||||
state: State<AppState<DB>>,
|
||||
Json(records): Json<Vec<Record<EncryptedData>>>,
|
||||
) -> Result<(), ErrorResponseStatus<'static>> {
|
||||
let State(AppState { database, settings }) = state;
|
||||
|
||||
tracing::debug!(
|
||||
count = records.len(),
|
||||
user = user.username,
|
||||
"request to add records"
|
||||
);
|
||||
|
||||
counter!("atuin_record_uploaded", records.len() as u64);
|
||||
|
||||
let too_big = records
|
||||
.iter()
|
||||
.any(|r| r.data.data.len() >= settings.max_record_size || settings.max_record_size == 0);
|
||||
|
||||
if too_big {
|
||||
counter!("atuin_record_too_large", 1);
|
||||
|
||||
return Err(
|
||||
ErrorResponse::reply("could not add records; record too large")
|
||||
.with_status(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(e) = database.add_records(&user, &records).await {
|
||||
error!("failed to add record: {}", e);
|
||||
|
||||
return Err(ErrorResponse::reply("failed to add record")
|
||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn index<DB: Database>(
|
||||
UserAuth(user): UserAuth,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> {
|
||||
let State(AppState {
|
||||
database,
|
||||
settings: _,
|
||||
}) = state;
|
||||
|
||||
let record_index = match database.status(&user).await {
|
||||
Ok(index) => index,
|
||||
Err(e) => {
|
||||
error!("failed to get record index: {}", e);
|
||||
|
||||
return Err(ErrorResponse::reply("failed to calculate record index")
|
||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(record_index))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct NextParams {
|
||||
host: HostId,
|
||||
tag: String,
|
||||
start: Option<RecordIdx>,
|
||||
count: u64,
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(user.id = user.id))]
|
||||
pub async fn next<DB: Database>(
|
||||
params: Query<NextParams>,
|
||||
UserAuth(user): UserAuth,
|
||||
state: State<AppState<DB>>,
|
||||
) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> {
|
||||
let State(AppState {
|
||||
database,
|
||||
settings: _,
|
||||
}) = state;
|
||||
let params = params.0;
|
||||
|
||||
let records = match database
|
||||
.next_records(&user, params.host, params.tag, params.start, params.count)
|
||||
.await
|
||||
{
|
||||
Ok(records) => records,
|
||||
Err(e) => {
|
||||
error!("failed to get record index: {}", e);
|
||||
|
||||
return Err(ErrorResponse::reply("failed to calculate record index")
|
||||
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(records))
|
||||
}
|
@ -118,13 +118,16 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
|
||||
.route("/sync/status", get(handlers::status::status))
|
||||
.route("/history", post(handlers::history::add))
|
||||
.route("/history", delete(handlers::history::delete))
|
||||
.route("/record", post(handlers::record::post))
|
||||
.route("/record", get(handlers::record::index))
|
||||
.route("/record/next", get(handlers::record::next))
|
||||
.route("/user/:username", get(handlers::user::get))
|
||||
.route("/account", delete(handlers::user::delete))
|
||||
.route("/register", post(handlers::user::register))
|
||||
.route("/login", post(handlers::user::login));
|
||||
.route("/login", post(handlers::user::login))
|
||||
.route("/record", post(handlers::record::post::<DB>))
|
||||
.route("/record", get(handlers::record::index::<DB>))
|
||||
.route("/record/next", get(handlers::record::next))
|
||||
.route("/api/v0/record", post(handlers::v0::record::post))
|
||||
.route("/api/v0/record", get(handlers::v0::record::index))
|
||||
.route("/api/v0/record/next", get(handlers::v0::record::next));
|
||||
|
||||
let path = settings.path.as_str();
|
||||
if path.is_empty() {
|
||||
|
@ -16,6 +16,7 @@ mod config;
|
||||
mod history;
|
||||
mod import;
|
||||
mod kv;
|
||||
mod record;
|
||||
mod search;
|
||||
mod stats;
|
||||
|
||||
@ -46,6 +47,9 @@ pub enum Cmd {
|
||||
#[command(subcommand)]
|
||||
Kv(kv::Cmd),
|
||||
|
||||
#[command(subcommand)]
|
||||
Record(record::Cmd),
|
||||
|
||||
/// Print example configuration
|
||||
#[command()]
|
||||
DefaultConfig,
|
||||
@ -79,21 +83,23 @@ impl Cmd {
|
||||
let record_store_path = PathBuf::from(settings.record_store_path.as_str());
|
||||
|
||||
let db = Sqlite::new(db_path).await?;
|
||||
let mut store = SqliteStore::new(record_store_path).await?;
|
||||
let store = SqliteStore::new(record_store_path).await?;
|
||||
|
||||
match self {
|
||||
Self::History(history) => history.run(&settings, &db).await,
|
||||
Self::History(history) => history.run(&settings, &db, store).await,
|
||||
Self::Import(import) => import.run(&db).await,
|
||||
Self::Stats(stats) => stats.run(&db, &settings).await,
|
||||
Self::Search(search) => search.run(db, &mut settings).await,
|
||||
|
||||
#[cfg(feature = "sync")]
|
||||
Self::Sync(sync) => sync.run(settings, &db, &mut store).await,
|
||||
Self::Sync(sync) => sync.run(settings, &db, &store).await,
|
||||
|
||||
#[cfg(feature = "sync")]
|
||||
Self::Account(account) => account.run(settings).await,
|
||||
|
||||
Self::Kv(kv) => kv.run(&settings, &mut store).await,
|
||||
Self::Kv(kv) => kv.run(&settings, &store).await,
|
||||
|
||||
Self::Record(record) => record.run(&settings, &store).await,
|
||||
|
||||
Self::DefaultConfig => {
|
||||
config::run();
|
||||
|
@ -12,7 +12,9 @@ use runtime_format::{FormatKey, FormatKeyError, ParseSegment, ParsedFmt};
|
||||
|
||||
use atuin_client::{
|
||||
database::{current_context, Database},
|
||||
history::History,
|
||||
encryption,
|
||||
history::{store::HistoryStore, History},
|
||||
record::{self, sqlite_store::SqliteStore},
|
||||
settings::Settings,
|
||||
};
|
||||
|
||||
@ -84,6 +86,10 @@ pub enum Cmd {
|
||||
#[arg(long, short)]
|
||||
format: Option<String>,
|
||||
},
|
||||
|
||||
/// Import all old history.db data into the record store. Do not run more than once, and do not
|
||||
/// run unless you know what you're doing (or the docs ask you to)
|
||||
InitStore,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
@ -266,11 +272,14 @@ impl Cmd {
|
||||
// we use this as the key for calling end
|
||||
println!("{}", h.id);
|
||||
db.save(&h).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_end(
|
||||
db: &impl Database,
|
||||
store: SqliteStore,
|
||||
history_store: HistoryStore,
|
||||
settings: &Settings,
|
||||
id: &str,
|
||||
exit: i64,
|
||||
@ -300,10 +309,20 @@ impl Cmd {
|
||||
};
|
||||
|
||||
db.update(&h).await?;
|
||||
history_store.push(h).await?;
|
||||
|
||||
if settings.should_sync()? {
|
||||
#[cfg(feature = "sync")]
|
||||
{
|
||||
if settings.sync.records {
|
||||
let (diff, _) = record::sync::diff(settings, &store).await?;
|
||||
let operations = record::sync::operations(diff, &store).await?;
|
||||
let (uploaded, downloaded) =
|
||||
record::sync::sync_remote(operations, &store, settings).await?;
|
||||
|
||||
println!("{uploaded}/{downloaded} up/down to record store");
|
||||
}
|
||||
|
||||
debug!("running periodic background sync");
|
||||
sync::sync(settings, false, db).await?;
|
||||
}
|
||||
@ -367,13 +386,56 @@ impl Cmd {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(self, settings: &Settings, db: &impl Database) -> Result<()> {
|
||||
async fn init_store(
|
||||
context: atuin_client::database::Context,
|
||||
db: &impl Database,
|
||||
store: HistoryStore,
|
||||
) -> Result<()> {
|
||||
println!("Importing all history.db data into records.db");
|
||||
|
||||
let history = db
|
||||
.list(
|
||||
atuin_client::settings::FilterMode::Global,
|
||||
&context,
|
||||
None,
|
||||
false,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for i in history {
|
||||
println!("loaded {}", i.id);
|
||||
|
||||
if i.deleted_at.is_some() {
|
||||
store.push(i.clone()).await?;
|
||||
store.delete(i.id).await?;
|
||||
} else {
|
||||
store.push(i).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
self,
|
||||
settings: &Settings,
|
||||
db: &impl Database,
|
||||
store: SqliteStore,
|
||||
) -> Result<()> {
|
||||
let context = current_context();
|
||||
|
||||
let encryption_key: [u8; 32] = encryption::load_key(settings)
|
||||
.context("could not load encryption key")?
|
||||
.into();
|
||||
|
||||
let host_id = Settings::host_id().expect("failed to get host_id");
|
||||
let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);
|
||||
|
||||
match self {
|
||||
Self::Start { command } => Self::handle_start(db, settings, &command).await,
|
||||
Self::End { id, exit, duration } => {
|
||||
Self::handle_end(db, settings, &id, exit, duration).await
|
||||
Self::handle_end(db, store, history_store, settings, &id, exit, duration).await
|
||||
}
|
||||
Self::List {
|
||||
session,
|
||||
@ -408,6 +470,8 @@ impl Cmd {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Self::InitStore => Self::init_store(context, db, history_store).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -35,11 +35,7 @@ pub enum Cmd {
|
||||
}
|
||||
|
||||
impl Cmd {
|
||||
pub async fn run(
|
||||
&self,
|
||||
settings: &Settings,
|
||||
store: &mut (impl Store + Send + Sync),
|
||||
) -> Result<()> {
|
||||
pub async fn run(&self, settings: &Settings, store: &(impl Store + Send + Sync)) -> Result<()> {
|
||||
let kv_store = KvStore::new();
|
||||
|
||||
let encryption_key: [u8; 32] = encryption::load_key(settings)
|
||||
|
63
atuin/src/command/client/record.rs
Normal file
63
atuin/src/command/client/record.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use clap::Subcommand;
|
||||
use eyre::Result;
|
||||
|
||||
use atuin_client::{record::store::Store, settings::Settings};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
#[command(infer_subcommands = true)]
|
||||
pub enum Cmd {
|
||||
Status,
|
||||
}
|
||||
|
||||
impl Cmd {
|
||||
pub async fn run(
|
||||
&self,
|
||||
_settings: &Settings,
|
||||
store: &(impl Store + Send + Sync),
|
||||
) -> Result<()> {
|
||||
let host_id = Settings::host_id().expect("failed to get host_id");
|
||||
|
||||
let status = store.status().await?;
|
||||
|
||||
// TODO: should probs build some data structure and then pretty-print it or smth
|
||||
for (host, st) in &status.hosts {
|
||||
let host_string = if host == &host_id {
|
||||
format!("host: {} <- CURRENT HOST", host.0.as_hyphenated())
|
||||
} else {
|
||||
format!("host: {}", host.0.as_hyphenated())
|
||||
};
|
||||
|
||||
println!("{host_string}");
|
||||
|
||||
for (tag, idx) in st {
|
||||
println!("\tstore: {tag}");
|
||||
|
||||
let first = store.first(*host, tag).await?;
|
||||
let last = store.last(*host, tag).await?;
|
||||
|
||||
println!("\t\tidx: {idx}");
|
||||
|
||||
if let Some(first) = first {
|
||||
println!("\t\tfirst: {}", first.id.0.as_hyphenated());
|
||||
|
||||
let time =
|
||||
OffsetDateTime::from_unix_timestamp_nanos(i128::from(first.timestamp))?;
|
||||
println!("\t\t\tcreated: {time}");
|
||||
}
|
||||
|
||||
if let Some(last) = last {
|
||||
println!("\t\tlast: {}", last.id.0.as_hyphenated());
|
||||
|
||||
let time =
|
||||
OffsetDateTime::from_unix_timestamp_nanos(i128::from(last.timestamp))?;
|
||||
println!("\t\t\tcreated: {time}");
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -45,7 +45,7 @@ impl Cmd {
|
||||
self,
|
||||
settings: Settings,
|
||||
db: &impl Database,
|
||||
store: &mut (impl Store + Send + Sync),
|
||||
store: &(impl Store + Send + Sync),
|
||||
) -> Result<()> {
|
||||
match self {
|
||||
Self::Sync { force } => run(&settings, force, db, store).await,
|
||||
@ -75,14 +75,15 @@ async fn run(
|
||||
settings: &Settings,
|
||||
force: bool,
|
||||
db: &impl Database,
|
||||
store: &mut (impl Store + Send + Sync),
|
||||
store: &(impl Store + Send + Sync),
|
||||
) -> Result<()> {
|
||||
let (diff, remote_index) = sync::diff(settings, store).await?;
|
||||
let operations = sync::operations(diff, store).await?;
|
||||
let (uploaded, downloaded) =
|
||||
sync::sync_remote(operations, &remote_index, store, settings).await?;
|
||||
if settings.sync.records {
|
||||
let (diff, _) = sync::diff(settings, store).await?;
|
||||
let operations = sync::operations(diff, store).await?;
|
||||
let (uploaded, downloaded) = sync::sync_remote(operations, store, settings).await?;
|
||||
|
||||
println!("{uploaded}/{downloaded} up/down to record store");
|
||||
println!("{uploaded}/{downloaded} up/down to record store");
|
||||
}
|
||||
|
||||
atuin_client::sync::sync(settings, force, db).await?;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user