wip: need to do sqlite now

This commit is contained in:
Ellie Huxtable 2023-12-02 09:14:30 +00:00
parent aa9fda8e71
commit 26abf41be4
9 changed files with 331 additions and 249 deletions

1
Cargo.lock generated
View File

@ -230,6 +230,7 @@ dependencies = [
"shellexpand",
"sql-builder",
"sqlx",
"thiserror",
"time",
"tokio",
"typed-builder",

View File

@ -46,6 +46,7 @@ uuid = { version = "1.3", features = ["v4", "serde"] }
whoami = "1.1.2"
typed-builder = "0.15.0"
pretty_assertions = "1.3.0"
thiserror = "1.0"
[workspace.dependencies.reqwest]
version = "0.11"

View File

@ -48,6 +48,7 @@ rmp = { version = "0.8.11" }
typed-builder = { workspace = true }
tokio = { workspace = true }
semver = { workspace = true }
thiserror = { workspace = true }
futures = "0.3"
crypto_secretbox = "0.1.1"
generic-array = { version = "0.14", features = ["serde"] }

View File

@ -8,10 +8,9 @@ create table if not exists host(
-- this will become more useful when we allow for multiple recipients of
-- some data (same cek, multiple long term keys)
-- This could be a key per host rather than one global key, or separate users.
create table if not exists cek(
create table if not exists cek (
id integer primary key, -- normalization rowid
wpk text not null, -- the encryption key, wrapped with the main key
kid text not null, -- the key id we used to wrap the wpk
cek text unique not null,
);
create table if not exists store (

View File

@ -14,7 +14,7 @@ use sqlx::{
Row,
};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use atuin_common::record::{EncryptedData, Host, HostId, Record, RecordId, RecordIndex};
use uuid::Uuid;
use super::store::Store;
@ -49,6 +49,49 @@ impl SqliteStore {
Ok(Self { pool })
}
/// Resolve the numeric row id for `host` in the `host` table,
/// inserting a new row if this host has not been seen before.
///
/// The host UUID is stored as a hyphenated string. On a cache miss the
/// insert uses SQLite's `returning id` clause so the fresh row id is
/// fetched in a single round trip. Runs inside the caller-supplied
/// transaction, so the select-then-insert pair is atomic with the
/// caller's other writes.
async fn host(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, host: &Host) -> Result<u64> {
// try selecting the id from the host. return if exists, or insert new and return id
let res: Result<(i64,), sqlx::Error> =
sqlx::query_as("select id from host where host = ?1")
.bind(host.id.0.as_hyphenated().to_string())
.fetch_one(&mut **tx)
.await;
// NOTE(review): any select error (not just RowNotFound) falls through to
// the insert path below — confirm that is intended
if let Ok(res) = res {
return Ok(res.0 as u64);
}
let res: (i64,) =
sqlx::query_as("insert into host(host, name) values (?1, ?2) returning id")
.bind(host.id.0.as_hyphenated().to_string())
.bind(host.name.as_str())
.fetch_one(&mut **tx)
.await?;
// SQLite rowids are positive, so the i64 -> u64 cast is safe here
Ok(res.0 as u64)
}
/// Resolve the numeric row id for a content encryption key string in
/// the `cek` table, inserting a new row if it has not been seen before.
///
/// On a cache miss the insert uses SQLite's `returning id` clause so
/// the fresh row id is fetched in a single round trip. Runs inside the
/// caller-supplied transaction.
async fn cek(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, cek: &str) -> Result<u64> {
// try selecting the id for the cek. return if exists, or insert new and return id
let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select id from cek where cek = ?1")
.bind(cek)
.fetch_one(&mut **tx)
.await;
// NOTE(review): any select error (not just RowNotFound) falls through to
// the insert path below — confirm that is intended
if let Ok(res) = res {
return Ok(res.0 as u64);
}
let res: (i64,) = sqlx::query_as("insert into cek(cek) values (?1) returning id")
.bind(cek)
.fetch_one(&mut **tx)
.await?;
// SQLite rowids are positive, so the i64 -> u64 cast is safe here
Ok(res.0 as u64)
}
async fn setup_db(pool: &SqlitePool) -> Result<()> {
debug!("running sqlite database setup");
@ -61,19 +104,22 @@ impl SqliteStore {
tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
r: &Record<EncryptedData>,
) -> Result<()> {
let host = Self::host(tx, &r.host).await?;
let cek = Self::cek(tx, r.data.content_encryption_key.as_str()).await?;
// In sqlite, we are "limited" to i64. But that is still fine, until 2262.
sqlx::query(
"insert or ignore into records(id, host, tag, timestamp, parent, version, data, cek)
"insert or ignore into store(id, idx, host, cek, timestamp, tag, version, data)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
.bind(r.id.0.as_simple().to_string())
.bind(r.host.0.as_simple().to_string())
.bind(r.tag.as_str())
.bind(r.idx as i64)
.bind(host as i64)
.bind(cek as i64)
.bind(r.timestamp as i64)
.bind(r.parent.map(|p| p.0.as_simple().to_string()))
.bind(r.tag.as_str())
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
.execute(&mut **tx)
.await?;
@ -81,26 +127,24 @@ impl SqliteStore {
}
fn query_row(row: SqliteRow) -> Record<EncryptedData> {
let idx: i64 = row.get("idx");
let timestamp: i64 = row.get("timestamp");
// the database state is irrecoverably inconsistent at this point, so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
let parent: Option<&str> = row.get("parent");
let parent = parent
.map(|parent| Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"));
let host =
Uuid::from_str(row.get("host.host")).expect("invalid host UUID format in sqlite DB");
Record {
id: RecordId(id),
host: HostId(host),
parent: parent.map(RecordId),
idx: idx as u64,
host: Host::new(HostId(host)),
timestamp: timestamp as u64,
tag: row.get("tag"),
version: row.get("version"),
data: EncryptedData {
data: row.get("data"),
content_encryption_key: row.get("cek"),
content_encryption_key: row.get("cek.cek"),
},
}
}
@ -124,7 +168,7 @@ impl Store for SqliteStore {
}
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
let res = sqlx::query("select * from records where id = ?1")
let res = sqlx::query("select * from store inner join host on store.host=host.id inner join cek on store.cek=cek.id where store.id = ?1")
.bind(id.0.as_simple().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
@ -133,9 +177,9 @@ impl Store for SqliteStore {
Ok(res)
}
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
async fn last(&self, host: HostId, tag: &str) -> Result<u64> {
let res: (i64,) =
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
sqlx::query_as("select max(idx) from records where host = ?1 and tag = ?2")
.bind(host.0.as_simple().to_string())
.bind(tag)
.fetch_one(&self.pool)

View File

@ -21,7 +21,7 @@ pub trait Store {
) -> Result<()>;
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
async fn last(&self, host: HostId, tag: &str) -> Result<Option<u64>>;
/// Get the record that follows this record
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;

View File

@ -1,27 +1,49 @@
// do a sync :O
use eyre::Result;
use thiserror::Error;
use super::store::Store;
use crate::{api_client::Client, settings::Settings};
use atuin_common::record::{Diff, HostId, RecordId, RecordIndex};
use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus};
#[derive(Error, Debug)]
pub enum SyncError {
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
LocalAheadOtherHost,
#[error("some issue with the local database occured")]
LocalStoreError,
#[error("something has gone wrong with the sync logic: {msg:?}")]
SyncLogicError { msg: String },
}
#[derive(Debug, Eq, PartialEq)]
pub enum Operation {
// Either upload or download until the tail matches the below
// Either upload or download until the states matches the below
Upload {
tail: RecordId,
local: RecordIdx,
remote: Option<RecordIdx>,
host: HostId,
tag: String,
},
Download {
tail: RecordId,
local: Option<RecordIdx>,
remote: RecordIdx,
host: HostId,
tag: String,
},
Noop {
host: HostId,
tag: String,
},
}
pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Diff>, RecordIndex)> {
pub async fn diff(
settings: &Settings,
store: &mut impl Store,
) -> Result<(Vec<Diff>, RecordStatus)> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
@ -41,8 +63,9 @@ pub async fn diff(settings: &Settings, store: &mut impl Store) -> Result<(Vec<Di
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or downloading.
// In theory this could be done as a part of the diffing stage, but it's easier to reason
// about and test this way
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>> {
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>, SyncError> {
let mut operations = Vec::with_capacity(diffs.len());
let host = Settings::host_id().expect("got to record sync without a host id; abort");
for diff in diffs {
// First, try to fetch the tail
@ -50,30 +73,61 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
// host until it has the same tail. Ie, upload.
// If it does not exist locally, that means remote is ahead of us.
// Therefore, we need to download until our local tail matches
let record = store.get(diff.tail).await;
let last = store
.last(diff.host, diff.tag.as_str())
.await
.map_err(|_| SyncError::LocalStoreError)?;
let op = if record.is_ok() {
// if local has the ID, then we should find the actual tail of this
// store, so we know what we need to update the remote to.
let tail = store
.tail(diff.host, diff.tag.as_str())
.await?
.expect("failed to fetch last record, expected tag/host to exist");
// TODO(ellie) update the diffing so that it stores the context of the current tail
// that way, we can determine how much we need to upload.
// For now just keep uploading until tails match
Operation::Upload {
tail: tail.id,
host: diff.host,
tag: diff.tag,
let op = match (last, diff.remote) {
// We both have it! Could be either. Compare.
(Some(last), Some(remote)) => {
if last == remote {
// between the diff and now, a sync has somehow occurred.
// regardless, no work is needed!
Operation::Noop {
host: diff.host,
tag: diff.tag,
}
} else if last > remote {
Operation::Upload {
local: last,
remote: Some(remote),
host: diff.host,
tag: diff.tag,
}
} else {
Operation::Download {
local: Some(last),
remote,
host: diff.host,
tag: diff.tag,
}
}
}
} else {
Operation::Download {
tail: diff.tail,
// Remote has it, we don't. Gotta be download
(None, Some(remote)) => Operation::Download {
local: None,
remote,
host: diff.host,
tag: diff.tag,
},
// We have it, remote doesn't. Gotta be upload.
(Some(last), None) => Operation::Upload {
local: last,
remote: None,
host: diff.host,
tag: diff.tag,
},
// Neither side has any records for this (host, tag) pair — the diff should never produce this.
(None, None) => {
return Err(SyncError::SyncLogicError {
msg: String::from(
"diff has nothing for local or remote - (host, tag) does not exist",
),
})
}
};
@ -86,8 +140,11 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
// with the same properties
operations.sort_by_key(|op| match op {
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
Operation::Download { tail, host, .. } => ("download", *host, *tail),
Operation::Noop { host, tag } => (0, *host, tag.clone()),
Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
Ok(operations)
@ -95,133 +152,66 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
async fn sync_upload(
store: &mut impl Store,
remote_index: &RecordIndex,
client: &Client<'_>,
op: (HostId, String, RecordId),
) -> Result<i64> {
host: HostId,
tag: String,
local: RecordIdx,
remote: Option<RecordIdx>,
) -> Result<i64, SyncError> {
let expected = local - remote.unwrap_or(0);
let upload_page_size = 100;
let mut total = 0;
// so. we have an upload operation, with the tail representing the state
// we want to get the remote to
let current_tail = remote_index.get(op.0, op.1.clone());
if expected < 0 {
return Err(SyncError::SyncLogicError {
msg: String::from("ran upload, but remote ahead of local"),
});
}
println!(
"Syncing local {:?}/{}/{:?}, remote has {:?}",
op.0, op.1, op.2, current_tail
"Uploading {} records to {}/{}",
expected,
host.0.as_simple().to_string(),
tag
);
let start = if let Some(current_tail) = current_tail {
current_tail
} else {
store
.head(op.0, op.1.as_str())
.await
.expect("failed to fetch host/tag head")
.expect("host/tag not in current index")
.id
};
// TODO: actually perform the upload
debug!("starting push to remote from: {:?}", start);
// we have the start point for sync. it is either the head of the store if
// the remote has no data for it, or the tail that the remote has
// we need to iterate from the remote tail, and keep going until
// remote tail = current local tail
let mut record = if current_tail.is_some() {
let r = store.get(start).await.unwrap();
store.next(&r).await?
} else {
Some(store.get(start).await.unwrap())
};
let mut buf = Vec::with_capacity(upload_page_size);
while let Some(r) = record {
if buf.len() < upload_page_size {
buf.push(r.clone());
} else {
client.post_records(&buf).await?;
// can we reset what we have? len = 0 but keep capacity
buf = Vec::with_capacity(upload_page_size);
}
record = store.next(&r).await?;
total += 1;
}
if !buf.is_empty() {
client.post_records(&buf).await?;
}
Ok(total)
Ok(0)
}
async fn sync_download(
store: &mut impl Store,
remote_index: &RecordIndex,
client: &Client<'_>,
op: (HostId, String, RecordId),
) -> Result<i64> {
// TODO(ellie): implement variable page sizing like on history sync
let download_page_size = 1;
host: HostId,
tag: String,
local: Option<RecordIdx>,
remote: RecordIdx,
) -> Result<i64, SyncError> {
let expected = remote - local.unwrap_or(0);
let download_page_size = 100;
let mut total = 0;
// We know that the remote is ahead of us, so let's keep downloading until both
// 1) The remote stops returning full pages
// 2) The tail equals what we expect
//
// If (1) occurs without (2), then something is wrong with our index calculation
// and we should bail.
let remote_tail = remote_index
.get(op.0, op.1.clone())
.expect("remote index does not contain expected tail during download");
let local_tail = store.tail(op.0, op.1.as_str()).await?;
//
// We expect that the operations diff will represent the desired state
// In this case, that contains the remote tail.
assert_eq!(remote_tail, op.2);
debug!("Downloading {:?}/{}/{:?} to local", op.0, op.1, op.2);
let mut records = client
.next_records(
op.0,
op.1.clone(),
local_tail.map(|r| r.id),
download_page_size,
)
.await?;
debug!("received {} records from remote", records.len());
while !records.is_empty() {
total += std::cmp::min(download_page_size, records.len() as u64);
store.push_batch(records.iter()).await?;
if records.last().unwrap().id == remote_tail {
break;
}
records = client
.next_records(
op.0,
op.1.clone(),
records.last().map(|r| r.id),
download_page_size,
)
.await?;
if expected < 0 {
return Err(SyncError::SyncLogicError {
msg: String::from("ran download, but local ahead of remote"),
});
}
Ok(total as i64)
println!(
"Downloading {} records from {}/{}",
expected,
host.0.as_simple().to_string(),
tag
);
// TODO: actually perform the download (comment was copy-pasted from the upload path)
Ok(0)
}
pub async fn sync_remote(
operations: Vec<Operation>,
remote_index: &RecordIndex,
local_store: &mut impl Store,
settings: &Settings,
) -> Result<(i64, i64)> {
@ -238,14 +228,23 @@ pub async fn sync_remote(
// this can totally run in parallel, but lets get it working first
for i in operations {
match i {
Operation::Upload { tail, host, tag } => {
uploaded +=
sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
}
Operation::Download { tail, host, tag } => {
downloaded +=
sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
Operation::Upload {
host,
tag,
local,
remote,
} => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?,
Operation::Download {
host,
tag,
local,
remote,
} => {
downloaded += sync_download(local_store, &client, host, tag, local, remote).await?
}
Operation::Noop { .. } => continue,
}
}
@ -266,13 +265,16 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
.host(HostId(atuin_common::utils::uuid_v7()))
.host(atuin_common::record::Host::new(HostId(
atuin_common::utils::uuid_v7(),
)))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: String::new(),
content_encryption_key: String::new(),
})
.idx(0)
.build()
}
@ -322,9 +324,10 @@ mod tests {
assert_eq!(
operations[0],
Operation::Upload {
host: record.host,
host: record.host.id,
tag: record.tag,
tail: record.id
local: record.idx,
remote: None,
}
);
}
@ -338,7 +341,7 @@ mod tests {
let remote_ahead = test_record();
let local_ahead = shared_record
.new_child(vec![1, 2, 3])
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
@ -353,14 +356,16 @@ mod tests {
operations,
vec![
Operation::Download {
tail: remote_ahead.id,
host: remote_ahead.host,
host: remote_ahead.host.id,
tag: remote_ahead.tag,
local: None,
remote: 0,
},
Operation::Upload {
tail: local_ahead.id,
host: local_ahead.host,
host: local_ahead.host.id,
tag: local_ahead.tag,
local: 0,
remote: None,
},
]
);
@ -379,11 +384,11 @@ mod tests {
let second_shared = test_record();
let second_shared_remote_ahead = second_shared
.new_child(vec![1, 2, 3])
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local_ahead = shared_record
.new_child(vec![1, 2, 3])
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![
@ -407,30 +412,37 @@ mod tests {
let mut result_ops = vec![
Operation::Download {
tail: remote_known.id,
host: remote_known.host,
host: remote_known.host.id,
tag: remote_known.tag,
local: Some(second_shared.idx),
remote: second_shared_remote_ahead.idx,
},
Operation::Download {
tail: second_shared_remote_ahead.id,
host: second_shared.host,
host: second_shared.host.id,
tag: second_shared.tag,
local: None,
remote: remote_known.idx,
},
Operation::Upload {
tail: local_ahead.id,
host: local_ahead.host,
host: local_ahead.host.id,
tag: local_ahead.tag,
local: local_ahead.idx,
remote: Some(shared_record.idx),
},
Operation::Upload {
tail: local_known.id,
host: local_known.host,
host: local_known.host.id,
tag: local_known.tag,
local: local_known.idx,
remote: None,
},
];
result_ops.sort_by_key(|op| match op {
Operation::Upload { tail, host, .. } => ("upload", *host, *tail),
Operation::Download { tail, host, .. } => ("download", *host, *tail),
Operation::Noop { host, tag } => (0, *host, tag.clone()),
Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
assert_eq!(operations, result_ops);

View File

@ -18,16 +18,30 @@ pub struct EncryptedData {
pub struct Diff {
pub host: HostId,
pub tag: String,
pub tail: RecordId,
pub local: Option<RecordIdx>,
pub remote: Option<RecordIdx>,
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Host {
pub id: u64,
pub host: HostId,
pub id: HostId,
pub name: String,
}
impl Host {
/// Construct a [`Host`] from an id alone, with an empty `name`.
///
/// Used where only the host id is available (e.g. when rebuilding a
/// record from storage); the name can be filled in later.
pub fn new(id: HostId) -> Self {
Host {
id,
name: String::new(),
}
}
}
new_uuid!(RecordId);
new_uuid!(HostId);
pub type RecordIdx = u64;
/// A single record stored inside of our local database
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)]
pub struct Record<Data> {
@ -36,7 +50,7 @@ pub struct Record<Data> {
pub id: RecordId,
/// The integer record ID. This is only unique per (host, tag).
pub idx: u64,
pub idx: RecordIdx,
/// The unique ID of the host.
// TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store
@ -59,9 +73,6 @@ pub struct Record<Data> {
pub data: Data,
}
new_uuid!(RecordId);
new_uuid!(HostId);
/// Extra data from the record that should be encoded in the data
#[derive(Debug, Copy, Clone)]
pub struct AdditionalData<'a> {
@ -73,11 +84,11 @@ pub struct AdditionalData<'a> {
}
impl<Data> Record<Data> {
pub fn new_child(&self, data: Vec<u8>) -> Record<DecryptedData> {
pub fn append(&self, data: Vec<u8>) -> Record<DecryptedData> {
Record::builder()
.host(self.host)
.host(self.host.clone())
.version(self.version.clone())
.parent(Some(self.id))
.idx(self.idx + 1)
.tag(self.tag.clone())
.data(DecryptedData(data))
.build()
@ -87,74 +98,76 @@ impl<Data> Record<Data> {
/// An index representing the current state of the record stores
/// This can be both remote, or local, and compared in either direction
#[derive(Debug, Serialize, Deserialize)]
pub struct RecordIndex {
// A map of host -> tag -> tail
pub hosts: HashMap<HostId, HashMap<String, RecordId>>,
pub struct RecordStatus {
// A map of host -> tag -> max(idx)
pub hosts: HashMap<HostId, HashMap<String, RecordIdx>>,
}
impl Default for RecordIndex {
impl Default for RecordStatus {
fn default() -> Self {
Self::new()
}
}
impl Extend<(HostId, String, RecordId)> for RecordIndex {
fn extend<T: IntoIterator<Item = (HostId, String, RecordId)>>(&mut self, iter: T) {
for (host, tag, tail_id) in iter {
self.set_raw(host, tag, tail_id);
impl Extend<(HostId, String, RecordIdx)> for RecordStatus {
fn extend<T: IntoIterator<Item = (HostId, String, RecordIdx)>>(&mut self, iter: T) {
for (host, tag, tail_idx) in iter {
self.set_raw(host, tag, tail_idx);
}
}
}
impl RecordIndex {
pub fn new() -> RecordIndex {
RecordIndex {
impl RecordStatus {
pub fn new() -> RecordStatus {
RecordStatus {
hosts: HashMap::new(),
}
}
/// Insert a new tail record into the store
pub fn set(&mut self, tail: Record<DecryptedData>) {
self.set_raw(tail.host, tail.tag, tail.id)
self.set_raw(tail.host.id, tail.tag, tail.idx)
}
pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordId) {
pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) {
self.hosts.entry(host).or_default().insert(tag, tail_id);
}
pub fn get(&self, host: HostId, tag: String) -> Option<RecordId> {
pub fn get(&self, host: HostId, tag: String) -> Option<RecordIdx> {
self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned()
}
/// Diff this index with another, likely remote index.
/// The two diffs can then be reconciled, and the optimal change set calculated
/// Returns a tuple, with (host, tag, Option(OTHER))
/// OTHER is set to the value of the tail on the other machine. For example, if the
/// other machine has a different tail, it will be the differing tail. This is useful to
/// check if the other index is ahead of us, or behind.
/// If the other index does not have the (host, tag) pair, then the other value will be None.
/// OTHER is set to the value of the idx on the other machine. If it is greater than our index,
/// then we need to do some downloading. If it is smaller, then we need to do some uploading
/// Note that we cannot upload if we are not the owner of the record store - hosts can only
/// write to their own store.
pub fn diff(&self, other: &Self) -> Vec<Diff> {
let mut ret = Vec::new();
// First, we check if other has everything that self has
for (host, tag_map) in self.hosts.iter() {
for (tag, tail) in tag_map.iter() {
for (tag, idx) in tag_map.iter() {
match other.get(*host, tag.clone()) {
// The other store is all up to date! No diff.
Some(t) if t.eq(tail) => continue,
Some(t) if t.eq(idx) => continue,
// The other store does exist, but it is either ahead or behind us. A diff regardless
// The other store does exist, and it is either ahead or behind us. A diff regardless
Some(t) => ret.push(Diff {
host: *host,
tag: tag.clone(),
tail: t,
local: Some(*idx),
remote: Some(t),
}),
// The other store does not exist :O
None => ret.push(Diff {
host: *host,
tag: tag.clone(),
tail: *tail,
local: Some(*idx),
remote: None,
}),
};
}
@ -165,7 +178,7 @@ impl RecordIndex {
// account for that!
for (host, tag_map) in other.hosts.iter() {
for (tag, tail) in tag_map.iter() {
for (tag, idx) in tag_map.iter() {
match self.get(*host, tag.clone()) {
// If we have this host/tag combo, the comparison and diff will have already happened above
Some(_) => continue,
@ -173,13 +186,22 @@ impl RecordIndex {
None => ret.push(Diff {
host: *host,
tag: tag.clone(),
tail: *tail,
remote: Some(*idx),
local: None,
}),
};
}
}
ret.sort_by(|a, b| (a.host, a.tag.clone(), a.tail).cmp(&(b.host, b.tag.clone(), b.tail)));
// Stability is a nice property to have
ret.sort_by(|a, b| {
(a.host, a.tag.clone(), a.local, a.remote).cmp(&(
b.host,
b.tag.clone(),
b.local,
b.remote,
))
});
ret
}
}
@ -204,14 +226,14 @@ impl Record<DecryptedData> {
id: &self.id,
version: &self.version,
tag: &self.tag,
host: &self.host,
parent: self.parent.as_ref(),
host: &self.host.id,
idx: &self.idx,
};
Record {
data: E::encrypt(self.data, ad, key),
id: self.id,
host: self.host,
parent: self.parent,
idx: self.idx,
timestamp: self.timestamp,
version: self.version,
tag: self.tag,
@ -225,14 +247,14 @@ impl Record<EncryptedData> {
id: &self.id,
version: &self.version,
tag: &self.tag,
host: &self.host,
parent: self.parent.as_ref(),
host: &self.host.id,
idx: &self.idx,
};
Ok(Record {
data: E::decrypt(self.data, ad, key)?,
id: self.id,
host: self.host,
parent: self.parent,
idx: self.idx,
timestamp: self.timestamp,
version: self.version,
tag: self.tag,
@ -248,14 +270,14 @@ impl Record<EncryptedData> {
id: &self.id,
version: &self.version,
tag: &self.tag,
host: &self.host,
parent: self.parent.as_ref(),
host: &self.host.id,
idx: &self.idx,
};
Ok(Record {
data: E::re_encrypt(self.data, ad, old_key, new_key)?,
id: self.id,
host: self.host,
parent: self.parent,
idx: self.idx,
timestamp: self.timestamp,
version: self.version,
tag: self.tag,
@ -265,31 +287,32 @@ impl Record<EncryptedData> {
#[cfg(test)]
mod tests {
use crate::record::HostId;
use crate::record::{Host, HostId};
use super::{DecryptedData, Diff, Record, RecordIndex};
use super::{DecryptedData, Diff, Record, RecordStatus};
use pretty_assertions::assert_eq;
fn test_record() -> Record<DecryptedData> {
Record::builder()
.host(HostId(crate::utils::uuid_v7()))
.host(Host::new(HostId(crate::utils::uuid_v7())))
.version("v1".into())
.tag(crate::utils::uuid_v7().simple().to_string())
.data(DecryptedData(vec![0, 1, 2, 3]))
.idx(0)
.build()
}
#[test]
fn record_index() {
let mut index = RecordIndex::new();
let mut index = RecordStatus::new();
let record = test_record();
index.set(record.clone());
let tail = index.get(record.host, record.tag);
let tail = index.get(record.host.id, record.tag);
assert_eq!(
record.id,
record.idx,
tail.expect("tail not in store"),
"tail in store did not match"
);
@ -297,17 +320,17 @@ mod tests {
#[test]
fn record_index_overwrite() {
let mut index = RecordIndex::new();
let mut index = RecordStatus::new();
let record = test_record();
let child = record.new_child(vec![1, 2, 3]);
let child = record.append(vec![1, 2, 3]);
index.set(record.clone());
index.set(child.clone());
let tail = index.get(record.host, record.tag);
let tail = index.get(record.host.id, record.tag);
assert_eq!(
child.id,
child.idx,
tail.expect("tail not in store"),
"tail in store did not match"
);
@ -317,8 +340,8 @@ mod tests {
fn record_index_no_diff() {
// Here, they both have the same version and should have no diff
let mut index1 = RecordIndex::new();
let mut index2 = RecordIndex::new();
let mut index1 = RecordStatus::new();
let mut index2 = RecordStatus::new();
let record1 = test_record();
@ -334,11 +357,11 @@ mod tests {
fn record_index_single_diff() {
// Here, they both have the same stores, but one is ahead by a single record
let mut index1 = RecordIndex::new();
let mut index2 = RecordIndex::new();
let mut index1 = RecordStatus::new();
let mut index2 = RecordStatus::new();
let record1 = test_record();
let record2 = record1.new_child(vec![1, 2, 3]);
let record2 = record1.append(vec![1, 2, 3]);
index1.set(record1);
index2.set(record2.clone());
@ -349,9 +372,10 @@ mod tests {
assert_eq!(
diff[0],
Diff {
host: record2.host,
host: record2.host.id,
tag: record2.tag,
tail: record2.id
remote: Some(1),
local: Some(0)
}
);
}
@ -359,14 +383,14 @@ mod tests {
#[test]
fn record_index_multi_diff() {
// A much more complex case, with a bunch more checks
let mut index1 = RecordIndex::new();
let mut index2 = RecordIndex::new();
let mut index1 = RecordStatus::new();
let mut index2 = RecordStatus::new();
let store1record1 = test_record();
let store1record2 = store1record1.new_child(vec![1, 2, 3]);
let store1record2 = store1record1.append(vec![1, 2, 3]);
let store2record1 = test_record();
let store2record2 = store2record1.new_child(vec![1, 2, 3]);
let store2record2 = store2record1.append(vec![1, 2, 3]);
let store3record1 = test_record();

View File

@ -14,7 +14,7 @@ use self::{
models::{History, NewHistory, NewSession, NewUser, Session, User},
};
use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordStatus};
use serde::{de::DeserializeOwned, Serialize};
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
use tracing::instrument;
@ -73,7 +73,7 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
) -> DbResult<Vec<Record<EncryptedData>>>;
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex>;
async fn tail_records(&self, user: &User) -> DbResult<RecordStatus>;
async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
-> DbResult<i64>;