mirror of
https://github.com/atuinsh/atuin.git
synced 2024-11-25 01:34:13 +01:00
wip: need to do sqlite now
This commit is contained in:
parent
aa9fda8e71
commit
26abf41be4
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -230,6 +230,7 @@ dependencies = [
|
||||
"shellexpand",
|
||||
"sql-builder",
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"time",
|
||||
"tokio",
|
||||
"typed-builder",
|
||||
|
@ -46,6 +46,7 @@ uuid = { version = "1.3", features = ["v4", "serde"] }
|
||||
whoami = "1.1.2"
|
||||
typed-builder = "0.15.0"
|
||||
pretty_assertions = "1.3.0"
|
||||
thiserror = "1.0"
|
||||
|
||||
[workspace.dependencies.reqwest]
|
||||
version = "0.11"
|
||||
|
@ -48,6 +48,7 @@ rmp = { version = "0.8.11" }
|
||||
typed-builder = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
semver = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
futures = "0.3"
|
||||
crypto_secretbox = "0.1.1"
|
||||
generic-array = { version = "0.14", features = ["serde"] }
|
||||
|
@ -8,10 +8,9 @@ create table if not exists host(
|
||||
-- this will become more useful when we allow for multiple recipients of
|
||||
-- some data (same cek, multiple long term keys)
|
||||
-- This could be a key per host rather than one global key, or separate users.
|
||||
create table if not exists cek(
|
||||
create table if not exists cek (
|
||||
id integer primary key, -- normalization rowid
|
||||
wpk text not null, -- the encryption key, wrapped with the main key
|
||||
kid text not null, -- the key id we used to wrap the wpk
|
||||
cek text unique not null,
|
||||
);
|
||||
|
||||
create table if not exists store (
|
||||
|
@ -14,7 +14,7 @@ use sqlx::{
|
||||
Row,
|
||||
};
|
||||
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||
use atuin_common::record::{EncryptedData, Host, HostId, Record, RecordId, RecordIndex};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::store::Store;
|
||||
@ -49,6 +49,49 @@ impl SqliteStore {
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
async fn host(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, host: &Host) -> Result<u64> {
|
||||
// try selecting the id from the host. return if exists, or insert new and return id
|
||||
|
||||
let res: Result<(i64,), sqlx::Error> =
|
||||
sqlx::query_as("select id from host where host = ?1")
|
||||
.bind(host.id.0.as_hyphenated().to_string())
|
||||
.fetch_one(&mut **tx)
|
||||
.await;
|
||||
|
||||
if let Ok(res) = res {
|
||||
return Ok(res.0 as u64);
|
||||
}
|
||||
|
||||
let res: (i64,) =
|
||||
sqlx::query_as("insert into host(host, name) values (?1, ?2) returning id")
|
||||
.bind(host.id.0.as_hyphenated().to_string())
|
||||
.bind(host.name.as_str())
|
||||
.fetch_one(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(res.0 as u64)
|
||||
}
|
||||
|
||||
async fn cek(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, cek: &str) -> Result<u64> {
|
||||
// try selecting the id from the host. return if exists, or insert new and return id
|
||||
|
||||
let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select id from cek where cek = ?1")
|
||||
.bind(cek)
|
||||
.fetch_one(&mut **tx)
|
||||
.await;
|
||||
|
||||
if let Ok(res) = res {
|
||||
return Ok(res.0 as u64);
|
||||
}
|
||||
|
||||
let res: (i64,) = sqlx::query_as("insert into cek(cek) values (?1) returning id")
|
||||
.bind(cek)
|
||||
.fetch_one(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(res.0 as u64)
|
||||
}
|
||||
|
||||
async fn setup_db(pool: &SqlitePool) -> Result<()> {
|
||||
debug!("running sqlite database setup");
|
||||
|
||||
@ -61,19 +104,22 @@ impl SqliteStore {
|
||||
tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
|
||||
r: &Record<EncryptedData>,
|
||||
) -> Result<()> {
|
||||
let host = Self::host(tx, &r.host).await?;
|
||||
let cek = Self::cek(tx, r.data.content_encryption_key.as_str()).await?;
|
||||
|
||||
// In sqlite, we are "limited" to i64. But that is still fine, until 2262.
|
||||
sqlx::query(
|
||||
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
|
||||
"insert or ignore into store(id, idx, host, cek, timestamp, tag, version, data)
|
||||
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
||||
)
|
||||
.bind(r.id.0.as_simple().to_string())
|
||||
.bind(r.host.0.as_simple().to_string())
|
||||
.bind(r.tag.as_str())
|
||||
.bind(r.idx as i64)
|
||||
.bind(host as i64)
|
||||
.bind(cek as i64)
|
||||
.bind(r.timestamp as i64)
|
||||
.bind(r.parent.map(|p| p.0.as_simple().to_string()))
|
||||
.bind(r.tag.as_str())
|
||||
.bind(r.version.as_str())
|
||||
.bind(r.data.data.as_str())
|
||||
.bind(r.data.content_encryption_key.as_str())
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
|
||||
@ -81,26 +127,24 @@ impl SqliteStore {
|
||||
}
|
||||
|
||||
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
|
||||
let idx: i64 = row.get("idx");
|
||||
let timestamp: i64 = row.get("timestamp");
|
||||
|
||||
// tbh at this point things are pretty fucked so just panic
|
||||
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
|
||||
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
|
||||
let parent: Option<&str> = row.get("parent");
|
||||
|
||||
let parent = parent
|
||||
.map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
|
||||
let host =
|
||||
Uuid::from_str(row.get("host.host")).expect("invalid host UUID format in sqlite DB");
|
||||
|
||||
Record {
|
||||
id: RecordId(id),
|
||||
host: HostId(host),
|
||||
parent: parent.map(RecordId),
|
||||
idx: idx as u64,
|
||||
host: Host::new(HostId(host)),
|
||||
timestamp: timestamp as u64,
|
||||
tag: row.get("tag"),
|
||||
version: row.get("version"),
|
||||
data: EncryptedData {
|
||||
data: row.get("data"),
|
||||
content_encryption_key: row.get("cek"),
|
||||
content_encryption_key: row.get("cek.cek"),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -124,7 +168,7 @@ impl Store for SqliteStore {
|
||||
}
|
||||
|
||||
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
|
||||
let res = sqlx::query("select * from records where id = ?1")
|
||||
let res = sqlx::query("select * from store inner join host on store.host=host.id inner join cek on store.cek=cek.id where store.id = ?1")
|
||||
.bind(id.0.as_simple().to_string())
|
||||
.map(Self::query_row)
|
||||
.fetch_one(&self.pool)
|
||||
@ -133,9 +177,9 @@ impl Store for SqliteStore {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
|
||||
async fn last(&self, host: HostId, tag: &str) -> Result<u64> {
|
||||
let res: (i64,) =
|
||||
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
|
||||
sqlx::query_as("select max(idx) from records where host = ?1 and tag = ?2")
|
||||
.bind(host.0.as_simple().to_string())
|
||||
.bind(tag)
|
||||
.fetch_one(&self.pool)
|
||||
|
@ -21,7 +21,7 @@ pub trait Store {
|
||||
) -> Result<()>;
|
||||
|
||||
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
|
||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
|
||||
async fn last(&self, host: HostId, tag: &str) -> Result<Option<u64>>;
|
||||
|
||||
/// Get the record that follows this record
|
||||
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
|
||||
|
@ -1,27 +1,49 @@
|
||||
// do a sync :O
|
||||
use eyre::Result;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::store::Store;
|
||||
use crate::{api_client::Client, settings::Settings};
|
||||
|
||||
use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
|
||||
use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SyncError {
|
||||
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
|
||||
LocalAheadOtherHost,
|
||||
|
||||
#[error("some issue with the local database occured")]
|
||||
LocalStoreError,
|
||||
|
||||
#[error("something has gone wrong with the sync logic: {msg:?}")]
|
||||
SyncLogicError { msg: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub enum Operation {
|
||||
// Either upload or download until the tail matches the below
|
||||
// Either upload or download until the states matches the below
|
||||
Upload {
|
||||
tail: RecordId,
|
||||
local: RecordIdx,
|
||||
remote: Option<RecordIdx>,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
},
|
||||
Download {
|
||||
tail: RecordId,
|
||||
local: Option<RecordIdx>,
|
||||
remote: RecordIdx,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
},
|
||||
Noop {
|
||||
host: HostId,
|
||||
tag: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
|
||||
pub async fn diff(
|
||||
settings: &Settings,
|
||||
store: &mut impl Store,
|
||||
) -> Result<(Vec<Diff>, RecordStatus)> {
|
||||
let client = Client::new(
|
||||
&settings.sync_address,
|
||||
&settings.session_token,
|
||||
@ -41,8 +63,9 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
|
||||
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
|
||||
// In theory this could be done as a part of the diffing stage, but it's easier to reason
|
||||
// about and test this way
|
||||
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
|
||||
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>, SyncError> {
|
||||
let mut operations = Vec::with_capacity(diffs.len());
|
||||
let host = Settings::host_id().expect("got to record sync without a host id; abort");
|
||||
|
||||
for diff in diffs {
|
||||
// First, try to fetch the tail
|
||||
@ -50,30 +73,61 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
|
||||
// host until it has the same tail. Ie, upload.
|
||||
// If it does not exist locally, that means remote is ahead of us.
|
||||
// Therefore, we need to download until our local tail matches
|
||||
let record = store.get(diff.tail).await;
|
||||
let last = store
|
||||
.last(diff.host, diff.tag.as_str())
|
||||
.await
|
||||
.map_err(|_| SyncError::LocalStoreError)?;
|
||||
|
||||
let op = if record.is_ok() {
|
||||
// if local has the ID, then we should find the actual tail of this
|
||||
// store, so we know what we need to update the remote to.
|
||||
let tail = store
|
||||
.tail(diff.host, diff.tag.as_str())
|
||||
.await?
|
||||
.expect("failed to fetch last record, expected tag/host to exist");
|
||||
|
||||
// TODO(ellie) update the diffing so that it stores the context of the current tail
|
||||
// that way, we can determine how much we need to upload.
|
||||
// For now just keep uploading until tails match
|
||||
|
||||
Operation::Upload {
|
||||
tail: tail.id,
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
let op = match (last, diff.remote) {
|
||||
// We both have it! Could be either. Compare.
|
||||
(Some(last), Some(remote)) => {
|
||||
if last == remote {
|
||||
// between the diff and now, a sync has somehow occured.
|
||||
// regardless, no work is needed!
|
||||
Operation::Noop {
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
}
|
||||
} else if last > remote {
|
||||
Operation::Upload {
|
||||
local: last,
|
||||
remote: Some(remote),
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
}
|
||||
} else {
|
||||
Operation::Download {
|
||||
local: Some(last),
|
||||
remote,
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Operation::Download {
|
||||
tail: diff.tail,
|
||||
|
||||
// Remote has it, we don't. Gotta be download
|
||||
(None, Some(remote)) => Operation::Download {
|
||||
local: None,
|
||||
remote,
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
},
|
||||
|
||||
// We have it, remote doesn't. Gotta be upload.
|
||||
(Some(last), None) => Operation::Upload {
|
||||
local: last,
|
||||
remote: None,
|
||||
host: diff.host,
|
||||
tag: diff.tag,
|
||||
},
|
||||
|
||||
// something is pretty fucked.
|
||||
(None, None) => {
|
||||
return Err(SyncError::SyncLogicError {
|
||||
msg: String::from(
|
||||
"diff has nothing for local or remote - (host, tag) does not exist",
|
||||
),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
@ -86,8 +140,11 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
|
||||
// with the same properties
|
||||
|
||||
operations.sort_by_key(|op| match op {
|
||||
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
|
||||
Operation::Download { tail, host, .. } => ("download", *host, *tail),
|
||||
Operation::Noop { host, tag } => (0, *host, tag.clone()),
|
||||
|
||||
Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
|
||||
|
||||
Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
|
||||
});
|
||||
|
||||
Ok(operations)
|
||||
@ -95,133 +152,66 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
|
||||
|
||||
async fn sync_upload(
|
||||
store: &mut impl Store,
|
||||
remote_index: &RecordIndex,
|
||||
client: &Client<'_>,
|
||||
op: (HostId, String, RecordId),
|
||||
) -> Result<i64> {
|
||||
host: HostId,
|
||||
tag: String,
|
||||
local: RecordIdx,
|
||||
remote: Option<RecordIdx>,
|
||||
) -> Result<i64, SyncError> {
|
||||
let expected = local - remote.unwrap_or(0);
|
||||
let upload_page_size = 100;
|
||||
let mut total = 0;
|
||||
|
||||
// so. we have an upload operation, with the tail representing the state
|
||||
// we want to get the remote to
|
||||
let current_tail = remote_index.get(op.0, op.1.clone());
|
||||
if expected < 0 {
|
||||
return Err(SyncError::SyncLogicError {
|
||||
msg: String::from("ran upload, but remote ahead of local"),
|
||||
});
|
||||
}
|
||||
|
||||
println!(
|
||||
"Syncing local {:?}/{}/{:?}, remote has {:?}",
|
||||
op.0, op.1, op.2, current_tail
|
||||
"Uploading {} records to {}/{}",
|
||||
expected,
|
||||
host.0.as_simple().to_string(),
|
||||
tag
|
||||
);
|
||||
|
||||
let start = if let Some(current_tail) = current_tail {
|
||||
current_tail
|
||||
} else {
|
||||
store
|
||||
.head(op.0, op.1.as_str())
|
||||
.await
|
||||
.expect("failed to fetch host/tag head")
|
||||
.expect("host/tag not in current index")
|
||||
.id
|
||||
};
|
||||
// TODO: actually upload lmfao
|
||||
|
||||
debug!("starting push to remote from: {:?}", start);
|
||||
|
||||
// we have the start point for sync. it is either the head of the store if
|
||||
// the remote has no data for it, or the tail that the remote has
|
||||
// we need to iterate from the remote tail, and keep going until
|
||||
// remote tail = current local tail
|
||||
|
||||
let mut record = if current_tail.is_some() {
|
||||
let r = store.get(start).await.unwrap();
|
||||
store.next(&r).await?
|
||||
} else {
|
||||
Some(store.get(start).await.unwrap())
|
||||
};
|
||||
|
||||
let mut buf = Vec::with_capacity(upload_page_size);
|
||||
|
||||
while let Some(r) = record {
|
||||
if buf.len() < upload_page_size {
|
||||
buf.push(r.clone());
|
||||
} else {
|
||||
client.post_records(&buf).await?;
|
||||
|
||||
// can we reset what we have? len = 0 but keep capacity
|
||||
buf = Vec::with_capacity(upload_page_size);
|
||||
}
|
||||
record = store.next(&r).await?;
|
||||
|
||||
total += 1;
|
||||
}
|
||||
|
||||
if !buf.is_empty() {
|
||||
client.post_records(&buf).await?;
|
||||
}
|
||||
|
||||
Ok(total)
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn sync_download(
|
||||
store: &mut impl Store,
|
||||
remote_index: &RecordIndex,
|
||||
client: &Client<'_>,
|
||||
op: (HostId, String, RecordId),
|
||||
) -> Result<i64> {
|
||||
// TODO(ellie): implement variable page sizing like on history sync
|
||||
let download_page_size = 1;
|
||||
|
||||
host: HostId,
|
||||
tag: String,
|
||||
local: Option<RecordIdx>,
|
||||
remote: RecordIdx,
|
||||
) -> Result<i64, SyncError> {
|
||||
let expected = remote - local.unwrap_or(0);
|
||||
let download_page_size = 100;
|
||||
let mut total = 0;
|
||||
|
||||
// We know that the remote is ahead of us, so let's keep downloading until both
|
||||
// 1) The remote stops returning full pages
|
||||
// 2) The tail equals what we expect
|
||||
//
|
||||
// If (1) occurs without (2), then something is wrong with our index calculation
|
||||
// and we should bail.
|
||||
let remote_tail = remote_index
|
||||
.get(op.0, op.1.clone())
|
||||
.expect("remote index does not contain expected tail during download");
|
||||
let local_tail = store.tail(op.0, op.1.as_str()).await?;
|
||||
//
|
||||
// We expect that the operations diff will represent the desired state
|
||||
// In this case, that contains the remote tail.
|
||||
assert_eq!(remote_tail, op.2);
|
||||
|
||||
debug!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
|
||||
|
||||
let mut records = client
|
||||
.next_records(
|
||||
op.0,
|
||||
op.1.clone(),
|
||||
local_tail.map(|r| r.id),
|
||||
download_page_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("received {} records from remote", records.len());
|
||||
|
||||
while !records.is_empty() {
|
||||
total += std::cmp::min(download_page_size, records.len() as u64);
|
||||
store.push_batch(records.iter()).await?;
|
||||
|
||||
if records.last().unwrap().id == remote_tail {
|
||||
break;
|
||||
}
|
||||
|
||||
records = client
|
||||
.next_records(
|
||||
op.0,
|
||||
op.1.clone(),
|
||||
records.last().map(|r| r.id),
|
||||
download_page_size,
|
||||
)
|
||||
.await?;
|
||||
if expected < 0 {
|
||||
return Err(SyncError::SyncLogicError {
|
||||
msg: String::from("ran download, but local ahead of remote"),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(total as i64)
|
||||
println!(
|
||||
"Downloading {} records from {}/{}",
|
||||
expected,
|
||||
host.0.as_simple().to_string(),
|
||||
tag
|
||||
);
|
||||
|
||||
// TODO: actually upload lmfao
|
||||
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
pub async fn sync_remote(
|
||||
operations: Vec<Operation>,
|
||||
remote_index: &RecordIndex,
|
||||
local_store: &mut impl Store,
|
||||
settings: &Settings,
|
||||
) -> Result<(i64, i64)> {
|
||||
@ -238,14 +228,23 @@ pub async fn sync_remote(
|
||||
// this can totally run in parallel, but lets get it working first
|
||||
for i in operations {
|
||||
match i {
|
||||
Operation::Upload { tail, host, tag } => {
|
||||
uploaded +=
|
||||
sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
|
||||
}
|
||||
Operation::Download { tail, host, tag } => {
|
||||
downloaded +=
|
||||
sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
|
||||
Operation::Upload {
|
||||
host,
|
||||
tag,
|
||||
local,
|
||||
remote,
|
||||
} => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?,
|
||||
|
||||
Operation::Download {
|
||||
host,
|
||||
tag,
|
||||
local,
|
||||
remote,
|
||||
} => {
|
||||
downloaded += sync_download(local_store, &client, host, tag, local, remote).await?
|
||||
}
|
||||
|
||||
Operation::Noop { .. } => continue,
|
||||
}
|
||||
}
|
||||
|
||||
@ -266,13 +265,16 @@ mod tests {
|
||||
|
||||
fn test_record() -> Record<EncryptedData> {
|
||||
Record::builder()
|
||||
.host(HostId(atuin_common::utils::uuid_v7()))
|
||||
.host(atuin_common::record::Host::new(HostId(
|
||||
atuin_common::utils::uuid_v7(),
|
||||
)))
|
||||
.version("v1".into())
|
||||
.tag(atuin_common::utils::uuid_v7().simple().to_string())
|
||||
.data(EncryptedData {
|
||||
data: String::new(),
|
||||
content_encryption_key: String::new(),
|
||||
})
|
||||
.idx(0)
|
||||
.build()
|
||||
}
|
||||
|
||||
@ -322,9 +324,10 @@ mod tests {
|
||||
assert_eq!(
|
||||
operations[0],
|
||||
Operation::Upload {
|
||||
host: record.host,
|
||||
host: record.host.id,
|
||||
tag: record.tag,
|
||||
tail: record.id
|
||||
local: record.idx,
|
||||
remote: None,
|
||||
}
|
||||
);
|
||||
}
|
||||
@ -338,7 +341,7 @@ mod tests {
|
||||
|
||||
let remote_ahead = test_record();
|
||||
let local_ahead = shared_record
|
||||
.new_child(vec![1, 2, 3])
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
|
||||
@ -353,14 +356,16 @@ mod tests {
|
||||
operations,
|
||||
vec![
|
||||
Operation::Download {
|
||||
tail: remote_ahead.id,
|
||||
host: remote_ahead.host,
|
||||
host: remote_ahead.host.id,
|
||||
tag: remote_ahead.tag,
|
||||
local: None,
|
||||
remote: 0,
|
||||
},
|
||||
Operation::Upload {
|
||||
tail: local_ahead.id,
|
||||
host: local_ahead.host,
|
||||
host: local_ahead.host.id,
|
||||
tag: local_ahead.tag,
|
||||
local: 0,
|
||||
remote: None,
|
||||
},
|
||||
]
|
||||
);
|
||||
@ -379,11 +384,11 @@ mod tests {
|
||||
|
||||
let second_shared = test_record();
|
||||
let second_shared_remote_ahead = second_shared
|
||||
.new_child(vec![1, 2, 3])
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let local_ahead = shared_record
|
||||
.new_child(vec![1, 2, 3])
|
||||
.append(vec![1, 2, 3])
|
||||
.encrypt::<PASETO_V4>(&[0; 32]);
|
||||
|
||||
let local = vec![
|
||||
@ -407,30 +412,37 @@ mod tests {
|
||||
|
||||
let mut result_ops = vec![
|
||||
Operation::Download {
|
||||
tail: remote_known.id,
|
||||
host: remote_known.host,
|
||||
host: remote_known.host.id,
|
||||
tag: remote_known.tag,
|
||||
local: Some(second_shared.idx),
|
||||
remote: second_shared_remote_ahead.idx,
|
||||
},
|
||||
Operation::Download {
|
||||
tail: second_shared_remote_ahead.id,
|
||||
host: second_shared.host,
|
||||
host: second_shared.host.id,
|
||||
tag: second_shared.tag,
|
||||
local: None,
|
||||
remote: remote_known.idx,
|
||||
},
|
||||
Operation::Upload {
|
||||
tail: local_ahead.id,
|
||||
host: local_ahead.host,
|
||||
host: local_ahead.host.id,
|
||||
tag: local_ahead.tag,
|
||||
local: local_ahead.idx,
|
||||
remote: Some(shared_record.idx),
|
||||
},
|
||||
Operation::Upload {
|
||||
tail: local_known.id,
|
||||
host: local_known.host,
|
||||
host: local_known.host.id,
|
||||
tag: local_known.tag,
|
||||
local: local_known.idx,
|
||||
remote: None,
|
||||
},
|
||||
];
|
||||
|
||||
result_ops.sort_by_key(|op| match op {
|
||||
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
|
||||
Operation::Download { tail, host, .. } => ("download", *host, *tail),
|
||||
Operation::Noop { host, tag } => (0, *host, tag.clone()),
|
||||
|
||||
Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
|
||||
|
||||
Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
|
||||
});
|
||||
|
||||
assert_eq!(operations, result_ops);
|
||||
|
@ -18,16 +18,30 @@ pub struct EncryptedData {
|
||||
pub struct Diff {
|
||||
pub host: HostId,
|
||||
pub tag: String,
|
||||
pub tail: RecordId,
|
||||
pub local: Option<RecordIdx>,
|
||||
pub remote: Option<RecordIdx>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
|
||||
pub struct Host {
|
||||
pub id: u64,
|
||||
pub host: HostId,
|
||||
pub id: HostId,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl Host {
|
||||
pub fn new(id: HostId) -> Self {
|
||||
Host {
|
||||
id,
|
||||
name: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_uuid!(RecordId);
|
||||
new_uuid!(HostId);
|
||||
|
||||
pub type RecordIdx = u64;
|
||||
|
||||
/// A single record stored inside of our local database
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
|
||||
pub struct Record<Data> {
|
||||
@ -36,7 +50,7 @@ pub struct Record<Data> {
|
||||
pub id: RecordId,
|
||||
|
||||
/// The integer record ID. This is only unique per (host, tag).
|
||||
pub idx: u64,
|
||||
pub idx: RecordIdx,
|
||||
|
||||
/// The unique ID of the host.
|
||||
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
|
||||
@ -59,9 +73,6 @@ pub struct Record<Data> {
|
||||
pub data: Data,
|
||||
}
|
||||
|
||||
new_uuid!(RecordId);
|
||||
new_uuid!(HostId);
|
||||
|
||||
/// Extra data from the record that should be encoded in the data
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct AdditionalData<'a> {
|
||||
@ -73,11 +84,11 @@ pub struct AdditionalData<'a> {
|
||||
}
|
||||
|
||||
impl<Data> Record<Data> {
|
||||
pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
|
||||
pub fn append(&self, data: Vec<u8>) -> Record<DecryptedData> {
|
||||
Record::builder()
|
||||
.host(self.host)
|
||||
.host(self.host.clone())
|
||||
.version(self.version.clone())
|
||||
.parent(Some(self.id))
|
||||
.idx(self.idx + 1)
|
||||
.tag(self.tag.clone())
|
||||
.data(DecryptedData(data))
|
||||
.build()
|
||||
@ -87,74 +98,76 @@ impl<Data> Record<Data> {
|
||||
/// An index representing the current state of the record stores
|
||||
/// This can be both remote, or local, and compared in either direction
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RecordIndex {
|
||||
// A map of host -> tag -> tail
|
||||
pub hosts: HashMap<HostId, HashMap<String, RecordId>>,
|
||||
pub struct RecordStatus {
|
||||
// A map of host -> tag -> max(idx)
|
||||
pub hosts: HashMap<HostId, HashMap<String, RecordIdx>>,
|
||||
}
|
||||
|
||||
impl Default for RecordIndex {
|
||||
impl Default for RecordStatus {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Extend<(HostId, String, RecordId)> for RecordIndex {
|
||||
fn extend<T: IntoIterator<Item = (HostId, String, RecordId)>>(&mut self, iter: T) {
|
||||
for (host, tag, tail_id) in iter {
|
||||
self.set_raw(host, tag, tail_id);
|
||||
impl Extend<(HostId, String, RecordIdx)> for RecordStatus {
|
||||
fn extend<T: IntoIterator<Item = (HostId, String, RecordIdx)>>(&mut self, iter: T) {
|
||||
for (host, tag, tail_idx) in iter {
|
||||
self.set_raw(host, tag, tail_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordIndex {
|
||||
pub fn new() -> RecordIndex {
|
||||
RecordIndex {
|
||||
impl RecordStatus {
|
||||
pub fn new() -> RecordStatus {
|
||||
RecordStatus {
|
||||
hosts: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a new tail record into the store
|
||||
pub fn set(&mut self, tail: Record<DecryptedData>) {
|
||||
self.set_raw(tail.host, tail.tag, tail.id)
|
||||
self.set_raw(tail.host.id, tail.tag, tail.idx)
|
||||
}
|
||||
|
||||
pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordId) {
|
||||
pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) {
|
||||
self.hosts.entry(host).or_default().insert(tag, tail_id);
|
||||
}
|
||||
|
||||
pub fn get(&self, host: HostId, tag: String) -> Option<RecordId> {
|
||||
pub fn get(&self, host: HostId, tag: String) -> Option<RecordIdx> {
|
||||
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
|
||||
}
|
||||
|
||||
/// Diff this index with another, likely remote index.
|
||||
/// The two diffs can then be reconciled, and the optimal change set calculated
|
||||
/// Returns a tuple, with (host, tag, Option(OTHER))
|
||||
/// OTHER is set to the value of the tail on the other machine. For example, if the
|
||||
/// other machine has a different tail, it will be the differing tail. This is useful to
|
||||
/// check if the other index is ahead of us, or behind.
|
||||
/// If the other index does not have the (host, tag) pair, then the other value will be None.
|
||||
/// OTHER is set to the value of the idx on the other machine. If it is greater than our index,
|
||||
/// then we need to do some downloading. If it is smaller, then we need to do some uploading
|
||||
/// Note that we cannot upload if we are not the owner of the record store - hosts can only
|
||||
/// write to their own store.
|
||||
pub fn diff(&self, other: &Self) -> Vec<Diff> {
|
||||
let mut ret = Vec::new();
|
||||
|
||||
// First, we check if other has everything that self has
|
||||
for (host, tag_map) in self.hosts.iter() {
|
||||
for (tag, tail) in tag_map.iter() {
|
||||
for (tag, idx) in tag_map.iter() {
|
||||
match other.get(*host, tag.clone()) {
|
||||
// The other store is all up to date! No diff.
|
||||
Some(t) if t.eq(tail) => continue,
|
||||
Some(t) if t.eq(idx) => continue,
|
||||
|
||||
// The other store does exist, but it is either ahead or behind us. A diff regardless
|
||||
// The other store does exist, and it is either ahead or behind us. A diff regardless
|
||||
Some(t) => ret.push(Diff {
|
||||
host: *host,
|
||||
tag: tag.clone(),
|
||||
tail: t,
|
||||
local: Some(*idx),
|
||||
remote: Some(t),
|
||||
}),
|
||||
|
||||
// The other store does not exist :O
|
||||
None => ret.push(Diff {
|
||||
host: *host,
|
||||
tag: tag.clone(),
|
||||
tail: *tail,
|
||||
local: Some(*idx),
|
||||
remote: None,
|
||||
}),
|
||||
};
|
||||
}
|
||||
@ -165,7 +178,7 @@ impl RecordIndex {
|
||||
|
||||
// account for that!
|
||||
for (host, tag_map) in other.hosts.iter() {
|
||||
for (tag, tail) in tag_map.iter() {
|
||||
for (tag, idx) in tag_map.iter() {
|
||||
match self.get(*host, tag.clone()) {
|
||||
// If we have this host/tag combo, the comparison and diff will have already happened above
|
||||
Some(_) => continue,
|
||||
@ -173,13 +186,22 @@ impl RecordIndex {
|
||||
None => ret.push(Diff {
|
||||
host: *host,
|
||||
tag: tag.clone(),
|
||||
tail: *tail,
|
||||
remote: Some(*idx),
|
||||
local: None,
|
||||
}),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
ret.sort_by(|a, b| (a.host, a.tag.clone(), a.tail).cmp(&(b.host, b.tag.clone(), b.tail)));
|
||||
// Stability is a nice property to have
|
||||
ret.sort_by(|a, b| {
|
||||
(a.host, a.tag.clone(), a.local, a.remote).cmp(&(
|
||||
b.host,
|
||||
b.tag.clone(),
|
||||
b.local,
|
||||
b.remote,
|
||||
))
|
||||
});
|
||||
ret
|
||||
}
|
||||
}
|
||||
@ -204,14 +226,14 @@ impl Record<DecryptedData> {
|
||||
id: &self.id,
|
||||
version: &self.version,
|
||||
tag: &self.tag,
|
||||
host: &self.host,
|
||||
parent: self.parent.as_ref(),
|
||||
host: &self.host.id,
|
||||
idx: &self.idx,
|
||||
};
|
||||
Record {
|
||||
data: E::encrypt(self.data, ad, key),
|
||||
id: self.id,
|
||||
host: self.host,
|
||||
parent: self.parent,
|
||||
idx: self.idx,
|
||||
timestamp: self.timestamp,
|
||||
version: self.version,
|
||||
tag: self.tag,
|
||||
@ -225,14 +247,14 @@ impl Record<EncryptedData> {
|
||||
id: &self.id,
|
||||
version: &self.version,
|
||||
tag: &self.tag,
|
||||
host: &self.host,
|
||||
parent: self.parent.as_ref(),
|
||||
host: &self.host.id,
|
||||
idx: &self.idx,
|
||||
};
|
||||
Ok(Record {
|
||||
data: E::decrypt(self.data, ad, key)?,
|
||||
id: self.id,
|
||||
host: self.host,
|
||||
parent: self.parent,
|
||||
idx: self.idx,
|
||||
timestamp: self.timestamp,
|
||||
version: self.version,
|
||||
tag: self.tag,
|
||||
@ -248,14 +270,14 @@ impl Record<EncryptedData> {
|
||||
id: &self.id,
|
||||
version: &self.version,
|
||||
tag: &self.tag,
|
||||
host: &self.host,
|
||||
parent: self.parent.as_ref(),
|
||||
host: &self.host.id,
|
||||
idx: &self.idx,
|
||||
};
|
||||
Ok(Record {
|
||||
data: E::re_encrypt(self.data, ad, old_key, new_key)?,
|
||||
id: self.id,
|
||||
host: self.host,
|
||||
parent: self.parent,
|
||||
idx: self.idx,
|
||||
timestamp: self.timestamp,
|
||||
version: self.version,
|
||||
tag: self.tag,
|
||||
@ -265,31 +287,32 @@ impl Record<EncryptedData> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::record::HostId;
|
||||
use crate::record::{Host, HostId};
|
||||
|
||||
use super::{DecryptedData, Diff, Record, RecordIndex};
|
||||
use super::{DecryptedData, Diff, Record, RecordStatus};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn test_record() -> Record<DecryptedData> {
|
||||
Record::builder()
|
||||
.host(HostId(crate::utils::uuid_v7()))
|
||||
.host(Host::new(HostId(crate::utils::uuid_v7())))
|
||||
.version("v1".into())
|
||||
.tag(crate::utils::uuid_v7().simple().to_string())
|
||||
.data(DecryptedData(vec![0, 1, 2, 3]))
|
||||
.idx(0)
|
||||
.build()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_index() {
|
||||
let mut index = RecordIndex::new();
|
||||
let mut index = RecordStatus::new();
|
||||
let record = test_record();
|
||||
|
||||
index.set(record.clone());
|
||||
|
||||
let tail = index.get(record.host, record.tag);
|
||||
let tail = index.get(record.host.id, record.tag);
|
||||
|
||||
assert_eq!(
|
||||
record.id,
|
||||
record.idx,
|
||||
tail.expect("tail not in store"),
|
||||
"tail in store did not match"
|
||||
);
|
||||
@ -297,17 +320,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn record_index_overwrite() {
|
||||
let mut index = RecordIndex::new();
|
||||
let mut index = RecordStatus::new();
|
||||
let record = test_record();
|
||||
let child = record.new_child(vec![1, 2, 3]);
|
||||
let child = record.append(vec![1, 2, 3]);
|
||||
|
||||
index.set(record.clone());
|
||||
index.set(child.clone());
|
||||
|
||||
let tail = index.get(record.host, record.tag);
|
||||
let tail = index.get(record.host.id, record.tag);
|
||||
|
||||
assert_eq!(
|
||||
child.id,
|
||||
child.idx,
|
||||
tail.expect("tail not in store"),
|
||||
"tail in store did not match"
|
||||
);
|
||||
@ -317,8 +340,8 @@ mod tests {
|
||||
fn record_index_no_diff() {
|
||||
// Here, they both have the same version and should have no diff
|
||||
|
||||
let mut index1 = RecordIndex::new();
|
||||
let mut index2 = RecordIndex::new();
|
||||
let mut index1 = RecordStatus::new();
|
||||
let mut index2 = RecordStatus::new();
|
||||
|
||||
let record1 = test_record();
|
||||
|
||||
@ -334,11 +357,11 @@ mod tests {
|
||||
fn record_index_single_diff() {
|
||||
// Here, they both have the same stores, but one is ahead by a single record
|
||||
|
||||
let mut index1 = RecordIndex::new();
|
||||
let mut index2 = RecordIndex::new();
|
||||
let mut index1 = RecordStatus::new();
|
||||
let mut index2 = RecordStatus::new();
|
||||
|
||||
let record1 = test_record();
|
||||
let record2 = record1.new_child(vec![1, 2, 3]);
|
||||
let record2 = record1.append(vec![1, 2, 3]);
|
||||
|
||||
index1.set(record1);
|
||||
index2.set(record2.clone());
|
||||
@ -349,9 +372,10 @@ mod tests {
|
||||
assert_eq!(
|
||||
diff[0],
|
||||
Diff {
|
||||
host: record2.host,
|
||||
host: record2.host.id,
|
||||
tag: record2.tag,
|
||||
tail: record2.id
|
||||
remote: Some(1),
|
||||
local: Some(0)
|
||||
}
|
||||
);
|
||||
}
|
||||
@ -359,14 +383,14 @@ mod tests {
|
||||
#[test]
|
||||
fn record_index_multi_diff() {
|
||||
// A much more complex case, with a bunch more checks
|
||||
let mut index1 = RecordIndex::new();
|
||||
let mut index2 = RecordIndex::new();
|
||||
let mut index1 = RecordStatus::new();
|
||||
let mut index2 = RecordStatus::new();
|
||||
|
||||
let store1record1 = test_record();
|
||||
let store1record2 = store1record1.new_child(vec![1, 2, 3]);
|
||||
let store1record2 = store1record1.append(vec![1, 2, 3]);
|
||||
|
||||
let store2record1 = test_record();
|
||||
let store2record2 = store2record1.new_child(vec![1, 2, 3]);
|
||||
let store2record2 = store2record1.append(vec![1, 2, 3]);
|
||||
|
||||
let store3record1 = test_record();
|
||||
|
||||
|
@ -14,7 +14,7 @@ use self::{
|
||||
models::{History, NewHistory, NewSession, NewUser, Session, User},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordStatus};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
|
||||
use tracing::instrument;
|
||||
@ -73,7 +73,7 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
|
||||
) -> DbResult<Vec<Record<EncryptedData>>>;
|
||||
|
||||
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
|
||||
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
|
||||
async fn tail_records(&self, user: &User) -> DbResult<RecordStatus>;
|
||||
|
||||
async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
|
||||
-> DbResult<i64>;
|
||||
|
Loading…
Reference in New Issue
Block a user